mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
feat(gpu): support keyswitch 64/32
This commit is contained in:
committed by
Andrei Stoian
parent
14d49f0891
commit
78d1ce18c1
@@ -44,7 +44,7 @@ mod decomposer;
|
||||
mod iter;
|
||||
mod term;
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
pub(crate) mod tests;
|
||||
|
||||
/// The level of a given term of a decomposition.
|
||||
///
|
||||
|
||||
@@ -12,7 +12,7 @@ use std::fmt::Debug;
|
||||
|
||||
pub const NB_TESTS: usize = 10_000_000;
|
||||
|
||||
fn valid_decomposers<T: UnsignedInteger>() -> Vec<SignedDecomposer<T>> {
|
||||
pub(crate) fn valid_decomposers<T: UnsignedInteger>() -> Vec<SignedDecomposer<T>> {
|
||||
let mut valid_decomposers = vec![];
|
||||
for base_log in (1..T::BITS).map(DecompositionBaseLog) {
|
||||
for level_count in (1..T::BITS).map(DecompositionLevelCount) {
|
||||
|
||||
@@ -46,7 +46,7 @@ impl<T: UnsignedInteger> TUniform<T> {
|
||||
/// representation of integers.
|
||||
pub const fn try_new(bound_log2: u32) -> Result<Self, &'static str> {
|
||||
if (bound_log2 + 2) as usize > T::BITS {
|
||||
return Err("Cannot create TUnfirorm: \
|
||||
return Err("Cannot create TUniform: \
|
||||
bound_log2 + 2 is greater than the current type's bit width");
|
||||
}
|
||||
|
||||
|
||||
@@ -2,11 +2,11 @@ use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
|
||||
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
|
||||
use crate::core_crypto::gpu::vec::CudaVec;
|
||||
use crate::core_crypto::gpu::{
|
||||
keyswitch_async, keyswitch_async_gemm, scratch_cuda_keyswitch_gemm_64, CudaStreams,
|
||||
cleanup_cuda_keyswitch_gemm_64, cuda_closest_representable_64, keyswitch_async,
|
||||
keyswitch_async_gemm, scratch_cuda_keyswitch_gemm_64, CudaStreams,
|
||||
};
|
||||
use crate::core_crypto::prelude::UnsignedInteger;
|
||||
use std::cmp::min;
|
||||
use tfhe_cuda_backend::bindings::cleanup_cuda_keyswitch_gemm_64;
|
||||
use tfhe_cuda_backend::ffi;
|
||||
|
||||
/// # Safety
|
||||
@@ -14,10 +14,10 @@ use tfhe_cuda_backend::ffi;
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
|
||||
/// be dropped until stream is synchronised
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar>(
|
||||
lwe_keyswitch_key: &CudaLweKeyswitchKey<Scalar>,
|
||||
pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar, KSKScalar>(
|
||||
lwe_keyswitch_key: &CudaLweKeyswitchKey<KSKScalar>,
|
||||
input_lwe_ciphertext: &CudaLweCiphertextList<Scalar>,
|
||||
output_lwe_ciphertext: &mut CudaLweCiphertextList<Scalar>,
|
||||
output_lwe_ciphertext: &mut CudaLweCiphertextList<KSKScalar>,
|
||||
input_indexes: &CudaVec<Scalar>,
|
||||
output_indexes: &CudaVec<Scalar>,
|
||||
uses_trivial_indices: bool,
|
||||
@@ -25,6 +25,7 @@ pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar>(
|
||||
use_gemm_ks: bool,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
KSKScalar: UnsignedInteger,
|
||||
{
|
||||
assert!(
|
||||
lwe_keyswitch_key.input_key_lwe_size().to_lwe_dimension()
|
||||
@@ -91,6 +92,7 @@ pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar>(
|
||||
);
|
||||
|
||||
if use_gemm_ks {
|
||||
// The scratch allocation uses the input LWE dtype to size its buffer
|
||||
cuda_scratch_keyswitch_lwe_ciphertext_async::<Scalar>(
|
||||
streams,
|
||||
std::ptr::addr_of_mut!(ks_tmp_buffer),
|
||||
@@ -100,6 +102,7 @@ pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar>(
|
||||
true,
|
||||
);
|
||||
|
||||
// GEMM KS can keyswitch from input LWE dtype Scalar to output LWE dtype KSKScalar
|
||||
keyswitch_async_gemm(
|
||||
streams,
|
||||
&mut output_lwe_ciphertext.0.d_vec,
|
||||
@@ -139,10 +142,10 @@ pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar>(
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn cuda_keyswitch_lwe_ciphertext<Scalar>(
|
||||
lwe_keyswitch_key: &CudaLweKeyswitchKey<Scalar>,
|
||||
pub fn cuda_keyswitch_lwe_ciphertext<Scalar, KSKScalar>(
|
||||
lwe_keyswitch_key: &CudaLweKeyswitchKey<KSKScalar>,
|
||||
input_lwe_ciphertext: &CudaLweCiphertextList<Scalar>,
|
||||
output_lwe_ciphertext: &mut CudaLweCiphertextList<Scalar>,
|
||||
output_lwe_ciphertext: &mut CudaLweCiphertextList<KSKScalar>,
|
||||
input_indexes: &CudaVec<Scalar>,
|
||||
output_indexes: &CudaVec<Scalar>,
|
||||
uses_trivial_indices: bool,
|
||||
@@ -150,6 +153,7 @@ pub fn cuda_keyswitch_lwe_ciphertext<Scalar>(
|
||||
use_gemm_ks: bool,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
KSKScalar: UnsignedInteger,
|
||||
{
|
||||
unsafe {
|
||||
cuda_keyswitch_lwe_ciphertext_async(
|
||||
@@ -207,3 +211,25 @@ pub unsafe fn cleanup_cuda_keyswitch_async<Scalar>(
|
||||
allocate_gpu_memory,
|
||||
);
|
||||
}
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
|
||||
/// be dropped until stream is synchronized
|
||||
pub unsafe fn cuda_closest_representable<Scalar>(
|
||||
streams: &CudaStreams,
|
||||
input: &CudaVec<Scalar>,
|
||||
output: &mut CudaVec<Scalar>,
|
||||
base_log: u32,
|
||||
level_count: u32,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
cuda_closest_representable_64(
|
||||
streams.ptr[0],
|
||||
streams.gpu_indexes[0].get(),
|
||||
input.as_c_ptr(0),
|
||||
output.as_mut_c_ptr(0),
|
||||
base_log,
|
||||
level_count,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
use super::*;
|
||||
use crate::core_crypto::commons::test_tools::any_uint;
|
||||
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
|
||||
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
|
||||
use crate::core_crypto::gpu::vec::{CudaVec, GpuIndex};
|
||||
use crate::core_crypto::gpu::{cuda_keyswitch_lwe_ciphertext, CudaStreams};
|
||||
use crate::core_crypto::gpu::{
|
||||
cuda_closest_representable, cuda_keyswitch_lwe_ciphertext, CudaStreams,
|
||||
};
|
||||
use crate::core_crypto::prelude::misc::check_encrypted_content_respects_mod;
|
||||
use itertools::Itertools;
|
||||
use rand::seq::SliceRandom;
|
||||
@@ -61,6 +64,11 @@ fn lwe_encrypt_ks_decrypt_custom_mod_mb<Scalar: UnsignedTorus + CastFrom<usize>>
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
// Used for both the Multi-Bit and Classic PBS settings.
|
||||
// Tests GEMM and Classic KS:
|
||||
// - tests that keyswitched LWE is decrypted correctly
|
||||
// - tests that GEMM and Classic KS are bit-wise equivalent
|
||||
// - tests that only a subset of LWEs can be keyswitched
|
||||
fn base_lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize>>(
|
||||
lwe_dimension: LweDimension,
|
||||
lwe_noise_distribution: DynamicDistribution<Scalar>,
|
||||
@@ -166,6 +174,7 @@ fn base_lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize
|
||||
} else {
|
||||
num_blocks
|
||||
};
|
||||
|
||||
let lwe_indexes_usize = (0..num_blocks).collect_vec();
|
||||
let mut lwe_indexes = lwe_indexes_usize.clone();
|
||||
|
||||
@@ -227,20 +236,35 @@ fn base_lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize
|
||||
assert_eq!(output_ct_list_gpu.lwe_ciphertext_count().0, num_blocks);
|
||||
// The output has `n_blocks` LWEs but only some are actually set - those
|
||||
// that correspond to output indices. We loop over all LWEs in the output buffer
|
||||
let output_ct_list_cpu = output_ct_list_gpu_gemm.to_lwe_ciphertext_list(&stream);
|
||||
let output_ct_list_cpu_gemm = output_ct_list_gpu_gemm.to_lwe_ciphertext_list(&stream);
|
||||
output_ct_list_gpu
|
||||
.to_lwe_ciphertext_list(&stream)
|
||||
.iter()
|
||||
.zip(0..num_blocks)
|
||||
.for_each(|(lwe_ct_out, i)| {
|
||||
let lwe_ct_out_gemm = output_ct_list_cpu_gemm.get(i);
|
||||
|
||||
let tmp_classical = lwe_ct_out.into_container();
|
||||
let tmp_gemm = lwe_ct_out_gemm.into_container();
|
||||
|
||||
// Compare bitwise the output of classical KS and GEMM KS
|
||||
for (v1, v2) in tmp_classical.iter().zip(tmp_gemm.iter()) {
|
||||
assert_eq!(*v1, *v2);
|
||||
}
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&lwe_ct_out,
|
||||
ciphertext_modulus
|
||||
));
|
||||
|
||||
let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &lwe_ct_out);
|
||||
// Check GEMM vs Classical bitwise equivalent
|
||||
let tmp_gemm = lwe_ct_out_gemm.into_container();
|
||||
for (v1, v2) in tmp_classical.iter().zip(tmp_gemm.iter()) {
|
||||
assert_eq!(v1, v2);
|
||||
}
|
||||
|
||||
let lwe_ct_out_gemm = output_ct_list_cpu.get(i);
|
||||
// Check GEMM & Classical KS decrypt to reference value
|
||||
let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &lwe_ct_out);
|
||||
let decrypted_gemm = decrypt_lwe_ciphertext(&lwe_sk, &lwe_ct_out_gemm);
|
||||
|
||||
let decoded = round_decode(decrypted.0, delta) % msg_modulus;
|
||||
@@ -253,5 +277,306 @@ fn base_lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn lwe_encrypt_ks_decrypt_ks32_common<
|
||||
Scalar: UnsignedTorus + CastFrom<usize> + CastInto<KSKScalar> + CastFrom<KSKScalar>,
|
||||
KSKScalar: UnsignedTorus + CastFrom<Scalar>,
|
||||
>(
|
||||
lwe_dimension: LweDimension,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
lwe_noise_distribution: DynamicDistribution<Scalar>,
|
||||
ks_decomp_base_log: DecompositionBaseLog,
|
||||
ks_decomp_level_count: DecompositionLevelCount,
|
||||
message_modulus_log: MessageModulusLog,
|
||||
ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
) {
|
||||
let input_encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
|
||||
|
||||
let output_ciphertext_modulus = CiphertextModulus::<KSKScalar>::new_native();
|
||||
let output_encoding_with_padding = get_encoding_with_padding(output_ciphertext_modulus);
|
||||
|
||||
let lwe_noise_distribution_u32 = match lwe_noise_distribution {
|
||||
DynamicDistribution::Gaussian(gaussian_lwe_noise_distribution) => {
|
||||
DynamicDistribution::<KSKScalar>::new_gaussian(
|
||||
gaussian_lwe_noise_distribution.standard_dev(),
|
||||
)
|
||||
}
|
||||
DynamicDistribution::TUniform(uniform_lwe_noise_distribution) => {
|
||||
DynamicDistribution::<KSKScalar>::new_t_uniform(
|
||||
uniform_lwe_noise_distribution.bound_log2(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let input_msg_modulus = Scalar::ONE << message_modulus_log.0;
|
||||
let output_msg_modulus = KSKScalar::ONE << message_modulus_log.0;
|
||||
|
||||
let input_delta = input_encoding_with_padding / input_msg_modulus;
|
||||
let output_delta = output_encoding_with_padding / output_msg_modulus;
|
||||
|
||||
let stream = CudaStreams::new_single_gpu(GpuIndex::new(0));
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
|
||||
let mut msg = input_msg_modulus;
|
||||
|
||||
let lwe_sk = allocate_and_generate_new_binary_lwe_secret_key(
|
||||
lwe_dimension,
|
||||
&mut rsc.secret_random_generator,
|
||||
);
|
||||
|
||||
let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key::<Scalar, _>(
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
&mut rsc.secret_random_generator,
|
||||
);
|
||||
|
||||
let big_lwe_sk_u32 = LweSecretKey::from_container(
|
||||
glwe_sk
|
||||
.as_ref()
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|x| x.cast_into())
|
||||
.collect::<Vec<KSKScalar>>(),
|
||||
);
|
||||
|
||||
let big_lwe_sk = glwe_sk.into_lwe_secret_key();
|
||||
|
||||
let ksk_big_to_small = allocate_and_generate_new_lwe_keyswitch_key(
|
||||
&big_lwe_sk_u32,
|
||||
&lwe_sk,
|
||||
ks_decomp_base_log,
|
||||
ks_decomp_level_count,
|
||||
lwe_noise_distribution_u32,
|
||||
output_ciphertext_modulus,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&ksk_big_to_small,
|
||||
output_ciphertext_modulus
|
||||
));
|
||||
|
||||
let d_ksk_big_to_small =
|
||||
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&ksk_big_to_small, &stream);
|
||||
|
||||
while msg != Scalar::ZERO {
|
||||
msg = msg.wrapping_sub(Scalar::ONE);
|
||||
for _ in 0..NB_TESTS {
|
||||
let plaintext = Plaintext(msg * input_delta);
|
||||
|
||||
let ct = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
&big_lwe_sk, //64b
|
||||
plaintext,
|
||||
lwe_noise_distribution,
|
||||
ciphertext_modulus,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&ct,
|
||||
ciphertext_modulus
|
||||
));
|
||||
let mut output_ct_ref = LweCiphertext::new(
|
||||
KSKScalar::ZERO,
|
||||
lwe_sk.lwe_dimension().to_lwe_size(),
|
||||
output_ciphertext_modulus,
|
||||
);
|
||||
|
||||
keyswitch_lwe_ciphertext_with_scalar_change(&ksk_big_to_small, &ct, &mut output_ct_ref);
|
||||
// 32b, 64b, 32b
|
||||
|
||||
let decrypted_cpu = decrypt_lwe_ciphertext(&lwe_sk, &output_ct_ref);
|
||||
let decoded_cpu: KSKScalar =
|
||||
round_decode(decrypted_cpu.0, output_delta) % output_msg_modulus;
|
||||
assert_eq!(msg, decoded_cpu.cast_into());
|
||||
|
||||
let d_ct = CudaLweCiphertextList::from_lwe_ciphertext(&ct, &stream);
|
||||
let mut d_output_ct = CudaLweCiphertextList::new(
|
||||
ksk_big_to_small.output_key_lwe_dimension(),
|
||||
LweCiphertextCount(1),
|
||||
output_ciphertext_modulus,
|
||||
&stream,
|
||||
);
|
||||
let mut d_output_ct_gemm = CudaLweCiphertextList::new(
|
||||
ksk_big_to_small.output_key_lwe_dimension(),
|
||||
LweCiphertextCount(1),
|
||||
output_ciphertext_modulus,
|
||||
&stream,
|
||||
);
|
||||
let num_blocks = d_ct.0.lwe_ciphertext_count.0;
|
||||
let lwe_indexes_usize = (0..num_blocks).collect_vec();
|
||||
let lwe_indexes = lwe_indexes_usize
|
||||
.iter()
|
||||
.map(|&x| <usize as CastInto<Scalar>>::cast_into(x))
|
||||
.collect_vec();
|
||||
let mut d_input_indexes =
|
||||
unsafe { CudaVec::<Scalar>::new_async(num_blocks, &stream, 0) };
|
||||
let mut d_output_indexes =
|
||||
unsafe { CudaVec::<Scalar>::new_async(num_blocks, &stream, 0) };
|
||||
unsafe { d_input_indexes.copy_from_cpu_async(&lwe_indexes, &stream, 0) };
|
||||
unsafe { d_output_indexes.copy_from_cpu_async(&lwe_indexes, &stream, 0) };
|
||||
|
||||
cuda_keyswitch_lwe_ciphertext(
|
||||
&d_ksk_big_to_small,
|
||||
&d_ct,
|
||||
&mut d_output_ct,
|
||||
&d_input_indexes,
|
||||
&d_output_indexes,
|
||||
true,
|
||||
&stream,
|
||||
false,
|
||||
);
|
||||
|
||||
cuda_keyswitch_lwe_ciphertext(
|
||||
&d_ksk_big_to_small,
|
||||
&d_ct,
|
||||
&mut d_output_ct_gemm,
|
||||
&d_input_indexes,
|
||||
&d_output_indexes,
|
||||
true,
|
||||
&stream,
|
||||
true,
|
||||
);
|
||||
|
||||
let output_ct = d_output_ct.into_lwe_ciphertext(&stream);
|
||||
|
||||
let tmp = output_ct.clone().into_container();
|
||||
let tmp_cpu = output_ct_ref.clone().into_container();
|
||||
for (v1, v2) in tmp.iter().zip(tmp_cpu.iter()) {
|
||||
assert_eq!(v1, v2);
|
||||
}
|
||||
|
||||
let output_ct_gemm = d_output_ct_gemm.into_lwe_ciphertext(&stream);
|
||||
|
||||
let tmp = output_ct_gemm.clone().into_container();
|
||||
let tmp_cpu = output_ct_ref.clone().into_container();
|
||||
for (v1, v2) in tmp.iter().zip(tmp_cpu.iter()) {
|
||||
assert_eq!(v1, v2);
|
||||
}
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&output_ct,
|
||||
output_ciphertext_modulus
|
||||
));
|
||||
|
||||
let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &output_ct);
|
||||
let decoded = round_decode(decrypted.0, output_delta) % output_msg_modulus;
|
||||
assert_eq!(msg, decoded.cast_into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn lwe_encrypt_ks_decrypt_custom_mod_ks32<
|
||||
Scalar: UnsignedTorus + CastFrom<usize> + CastInto<u32> + CastFrom<u32>,
|
||||
>(
|
||||
params: &MultiBitTestKS32Params<Scalar>,
|
||||
) where
|
||||
u32: CastFrom<Scalar>,
|
||||
{
|
||||
lwe_encrypt_ks_decrypt_ks32_common::<Scalar, u32>(
|
||||
params.lwe_dimension,
|
||||
params.glwe_dimension,
|
||||
params.polynomial_size,
|
||||
params.lwe_noise_distribution,
|
||||
params.ks_base_log,
|
||||
params.ks_level,
|
||||
params.message_modulus_log,
|
||||
params.ciphertext_modulus,
|
||||
);
|
||||
}
|
||||
|
||||
fn test_util_closest_representable_on_gpu(
|
||||
value: u64,
|
||||
base_log: DecompositionBaseLog,
|
||||
level_count: DecompositionLevelCount,
|
||||
) -> u64 {
|
||||
let stream = CudaStreams::new_single_gpu(GpuIndex::new(0));
|
||||
|
||||
let h_input: Vec<u64> = vec![value];
|
||||
|
||||
let mut d_input = unsafe { CudaVec::<u64>::new_async(1, &stream, 0) };
|
||||
unsafe { d_input.copy_from_cpu_async(&h_input, &stream, 0) };
|
||||
|
||||
let mut d_output = unsafe { CudaVec::<u64>::new_async(1, &stream, 0) };
|
||||
|
||||
unsafe {
|
||||
cuda_closest_representable(
|
||||
&stream,
|
||||
&d_input,
|
||||
&mut d_output,
|
||||
base_log.0 as u32,
|
||||
level_count.0 as u32,
|
||||
);
|
||||
}
|
||||
|
||||
let mut h_output: Vec<u64> = vec![0];
|
||||
unsafe {
|
||||
d_output.copy_to_cpu_async(&mut h_output, &stream, 0);
|
||||
}
|
||||
|
||||
stream.synchronize();
|
||||
|
||||
*h_output.first().unwrap()
|
||||
}
|
||||
|
||||
/// Regression test on a single hand-picked value: the GPU rounding must agree with the CPU
/// decomposer, and the CPU rounding must agree with decompose-then-recompose.
#[test]
fn test_closest_representable_gpu() {
    let base_log = DecompositionBaseLog(17);
    let level_count = DecompositionLevelCount(3);

    // This input makes the decomposition start from a negative state. With
    // base_log * (level_count + 1) > Scalar::BITS, an implementation relying on a logical
    // shift shifts in 0s instead of the 1s carrying the sign on the last level, and therefore
    // computes a wrong intermediate value.
    let val: u64 = 0x8000_00e3_55b0_c827;

    let decomposer = SignedDecomposer::new(base_log, level_count);

    let cpu_rounded = decomposer.closest_representable(val);
    let cpu_recomposed = decomposer.recompose(decomposer.decompose(val)).unwrap();
    let gpu_rounded = test_util_closest_representable_on_gpu(val, base_log, level_count);

    assert_eq!(cpu_rounded, cpu_recomposed);
    assert_eq!(gpu_rounded, cpu_rounded);
}
|
||||
|
||||
/// For every valid `u64` decomposer and a batch of random inputs, checks that the value
/// rounded on the GPU is a fixed point of the CPU rounding under a small perturbation.
#[test]
fn test_round_to_closest_representable_gpu() {
    const RUNS_PER_DECOMPOSER: usize = 100;

    let all_decomposers =
        crate::core_crypto::commons::math::decomposition::tests::valid_decomposers::<u64>();

    for decomposer in all_decomposers {
        // Perturbation size; it only depends on the decomposer, so compute it once.
        let epsilon = (1u64 << (64 - (decomposer.base_log * decomposer.level_count) - 1)) / 2u64;

        for _ in 0..RUNS_PER_DECOMPOSER {
            let val = any_uint::<u64>();

            let gpu_rounded = test_util_closest_representable_on_gpu(
                val,
                decomposer.base_log(),
                decomposer.level_count(),
            );

            // Adding/removing an epsilon should not change the closest representable: the CPU
            // decomposer must round both perturbed values back to the GPU result.
            // NOTE(review): the GPU result is never compared with the CPU rounding of `val`
            // itself — consider asserting that as well.
            let nudged_up = gpu_rounded.wrapping_add(epsilon);
            let nudged_down = gpu_rounded.wrapping_sub(epsilon);

            assert_eq!(gpu_rounded, decomposer.closest_representable(nudged_up));
            assert_eq!(gpu_rounded, decomposer.closest_representable(nudged_down));
        }
    }
}
|
||||
|
||||
create_gpu_parameterized_test!(lwe_encrypt_ks_decrypt_custom_mod);
|
||||
create_gpu_multi_bit_parameterized_test!(lwe_encrypt_ks_decrypt_custom_mod_mb);
|
||||
create_gpu_multi_bit_ks32_parameterized_test!(lwe_encrypt_ks_decrypt_custom_mod_ks32);
|
||||
|
||||
@@ -13,6 +13,7 @@ mod lwe_programmable_bootstrapping;
|
||||
mod lwe_programmable_bootstrapping_128;
|
||||
mod modulus_switch;
|
||||
mod noise_distribution;
|
||||
mod params;
|
||||
|
||||
pub struct CudaPackingKeySwitchKeys<Scalar: UnsignedInteger> {
|
||||
pub lwe_sk: LweSecretKey<Vec<Scalar>>,
|
||||
@@ -20,6 +21,24 @@ pub struct CudaPackingKeySwitchKeys<Scalar: UnsignedInteger> {
|
||||
pub pksk: CudaLwePackingKeyswitchKey<Scalar>,
|
||||
}
|
||||
|
||||
pub const MULTI_BIT_2_2_2_KS32_PARAMS: MultiBitTestKS32Params<u64> = MultiBitTestKS32Params {
|
||||
lwe_dimension: LweDimension(920),
|
||||
glwe_dimension: GlweDimension(1),
|
||||
polynomial_size: PolynomialSize(2048),
|
||||
lwe_noise_distribution: DynamicDistribution::new_t_uniform(13),
|
||||
glwe_noise_distribution: DynamicDistribution::new_t_uniform(17),
|
||||
pbs_base_log: DecompositionBaseLog(22),
|
||||
pbs_level: DecompositionLevelCount(1),
|
||||
ks_base_log: DecompositionBaseLog(3),
|
||||
ks_level: DecompositionLevelCount(5),
|
||||
message_modulus_log: MessageModulusLog(2),
|
||||
log2_p_fail: -134.345,
|
||||
ciphertext_modulus: CiphertextModulus::new_native(),
|
||||
encryption_key_choice: EncryptionKeyChoice::Big,
|
||||
grouping_factor: LweBskGroupingFactor(4),
|
||||
deterministic_execution: false,
|
||||
};
|
||||
|
||||
// Macro to generate tests for all parameter sets
|
||||
macro_rules! create_gpu_parameterized_test{
|
||||
($name:ident { $($param:ident),* }) => {
|
||||
@@ -60,6 +79,27 @@ macro_rules! create_gpu_multi_bit_parameterized_test{
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! create_gpu_multi_bit_ks32_parameterized_test{
|
||||
($name:ident { $($param:ident),* }) => {
|
||||
::paste::paste! {
|
||||
$(
|
||||
#[test]
|
||||
fn [<test_gpu_ $name _ $param:lower>]() {
|
||||
$name(&$param)
|
||||
}
|
||||
)*
|
||||
}
|
||||
};
|
||||
($name:ident)=> {
|
||||
create_gpu_multi_bit_ks32_parameterized_test!($name
|
||||
{
|
||||
MULTI_BIT_2_2_2_KS32_PARAMS
|
||||
});
|
||||
};
|
||||
}
|
||||
use crate::core_crypto::gpu::algorithms::test::params::MultiBitTestKS32Params;
|
||||
use crate::core_crypto::gpu::lwe_packing_keyswitch_key::CudaLwePackingKeyswitchKey;
|
||||
use {create_gpu_multi_bit_parameterized_test, create_gpu_parameterized_test};
|
||||
use {
|
||||
create_gpu_multi_bit_ks32_parameterized_test, create_gpu_multi_bit_parameterized_test,
|
||||
create_gpu_parameterized_test,
|
||||
};
|
||||
|
||||
26
tfhe/src/core_crypto/gpu/algorithms/test/params.rs
Normal file
26
tfhe/src/core_crypto/gpu/algorithms/test/params.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use crate::core_crypto::commons::math::random::{Deserialize, Serialize};
|
||||
use crate::core_crypto::prelude::{
|
||||
CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution,
|
||||
GlweDimension, LweBskGroupingFactor, LweDimension, MessageModulusLog, PolynomialSize,
|
||||
UnsignedInteger,
|
||||
};
|
||||
use crate::shortint::EncryptionKeyChoice;
|
||||
|
||||
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
|
||||
pub struct MultiBitTestKS32Params<Scalar: UnsignedInteger> {
|
||||
pub lwe_dimension: LweDimension,
|
||||
pub glwe_dimension: GlweDimension,
|
||||
pub polynomial_size: PolynomialSize,
|
||||
pub lwe_noise_distribution: DynamicDistribution<Scalar>,
|
||||
pub glwe_noise_distribution: DynamicDistribution<Scalar>,
|
||||
pub pbs_base_log: DecompositionBaseLog,
|
||||
pub pbs_level: DecompositionLevelCount,
|
||||
pub ks_base_log: DecompositionBaseLog,
|
||||
pub ks_level: DecompositionLevelCount,
|
||||
pub message_modulus_log: MessageModulusLog,
|
||||
pub log2_p_fail: f64,
|
||||
pub ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
pub encryption_key_choice: EncryptionKeyChoice,
|
||||
pub grouping_factor: LweBskGroupingFactor,
|
||||
pub deterministic_execution: bool,
|
||||
}
|
||||
@@ -485,37 +485,58 @@ pub fn get_programmable_bootstrap_multi_bit_size_on_gpu(
|
||||
/// [CudaStreams::synchronize] __must__ be called as soon as synchronization is
|
||||
/// required
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub unsafe fn keyswitch_async_gemm<T: UnsignedInteger>(
|
||||
pub unsafe fn keyswitch_async_gemm<T: UnsignedInteger, KST: UnsignedInteger>(
|
||||
streams: &CudaStreams,
|
||||
lwe_array_out: &mut CudaVec<T>,
|
||||
lwe_array_out: &mut CudaVec<KST>,
|
||||
lwe_out_indexes: &CudaVec<T>,
|
||||
lwe_array_in: &CudaVec<T>,
|
||||
lwe_in_indexes: &CudaVec<T>,
|
||||
input_lwe_dimension: LweDimension,
|
||||
output_lwe_dimension: LweDimension,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
keyswitch_key: &CudaVec<KST>,
|
||||
base_log: DecompositionBaseLog,
|
||||
l_gadget: DecompositionLevelCount,
|
||||
num_samples: u32,
|
||||
ks_tmp_buffer: *const ffi::c_void,
|
||||
uses_trivial_indices: bool,
|
||||
) {
|
||||
cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
|
||||
streams.ptr[0],
|
||||
streams.gpu_indexes[0].get(),
|
||||
lwe_array_out.as_mut_c_ptr(0),
|
||||
lwe_out_indexes.as_c_ptr(0),
|
||||
lwe_array_in.as_c_ptr(0),
|
||||
lwe_in_indexes.as_c_ptr(0),
|
||||
keyswitch_key.as_c_ptr(0),
|
||||
input_lwe_dimension.0 as u32,
|
||||
output_lwe_dimension.0 as u32,
|
||||
base_log.0 as u32,
|
||||
l_gadget.0 as u32,
|
||||
num_samples,
|
||||
ks_tmp_buffer,
|
||||
uses_trivial_indices,
|
||||
);
|
||||
if TypeId::of::<KST>() == TypeId::of::<u32>() {
|
||||
cuda_keyswitch_gemm_lwe_ciphertext_vector_64_32(
|
||||
streams.ptr[0],
|
||||
streams.gpu_indexes[0].get(),
|
||||
lwe_array_out.as_mut_c_ptr(0),
|
||||
lwe_out_indexes.as_c_ptr(0),
|
||||
lwe_array_in.as_c_ptr(0),
|
||||
lwe_in_indexes.as_c_ptr(0),
|
||||
keyswitch_key.as_c_ptr(0),
|
||||
input_lwe_dimension.0 as u32,
|
||||
output_lwe_dimension.0 as u32,
|
||||
base_log.0 as u32,
|
||||
l_gadget.0 as u32,
|
||||
num_samples,
|
||||
ks_tmp_buffer,
|
||||
uses_trivial_indices,
|
||||
);
|
||||
} else if TypeId::of::<KST>() == TypeId::of::<u64>() {
|
||||
cuda_keyswitch_gemm_lwe_ciphertext_vector_64_64(
|
||||
streams.ptr[0],
|
||||
streams.gpu_indexes[0].get(),
|
||||
lwe_array_out.as_mut_c_ptr(0),
|
||||
lwe_out_indexes.as_c_ptr(0),
|
||||
lwe_array_in.as_c_ptr(0),
|
||||
lwe_in_indexes.as_c_ptr(0),
|
||||
keyswitch_key.as_c_ptr(0),
|
||||
input_lwe_dimension.0 as u32,
|
||||
output_lwe_dimension.0 as u32,
|
||||
base_log.0 as u32,
|
||||
l_gadget.0 as u32,
|
||||
num_samples,
|
||||
ks_tmp_buffer,
|
||||
uses_trivial_indices,
|
||||
);
|
||||
} else {
|
||||
panic!("Unknown LWE GEMM KS dtype of size {}B", size_of::<KST>());
|
||||
}
|
||||
}
|
||||
|
||||
/// Keyswitch on a vector of LWE ciphertexts. Better for small batches of LWEs
|
||||
@@ -525,33 +546,50 @@ pub unsafe fn keyswitch_async_gemm<T: UnsignedInteger>(
|
||||
/// [CudaStreams::synchronize] __must__ be called as soon as synchronization is
|
||||
/// required
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub unsafe fn keyswitch_async<T: UnsignedInteger>(
|
||||
pub unsafe fn keyswitch_async<T: UnsignedInteger, KT: UnsignedInteger>(
|
||||
streams: &CudaStreams,
|
||||
lwe_array_out: &mut CudaVec<T>,
|
||||
lwe_array_out: &mut CudaVec<KT>,
|
||||
lwe_out_indexes: &CudaVec<T>,
|
||||
lwe_array_in: &CudaVec<T>,
|
||||
lwe_in_indexes: &CudaVec<T>,
|
||||
input_lwe_dimension: LweDimension,
|
||||
output_lwe_dimension: LweDimension,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
keyswitch_key: &CudaVec<KT>,
|
||||
base_log: DecompositionBaseLog,
|
||||
l_gadget: DecompositionLevelCount,
|
||||
num_samples: u32,
|
||||
) {
|
||||
cuda_keyswitch_lwe_ciphertext_vector_64(
|
||||
streams.ptr[0],
|
||||
streams.gpu_indexes[0].get(),
|
||||
lwe_array_out.as_mut_c_ptr(0),
|
||||
lwe_out_indexes.as_c_ptr(0),
|
||||
lwe_array_in.as_c_ptr(0),
|
||||
lwe_in_indexes.as_c_ptr(0),
|
||||
keyswitch_key.as_c_ptr(0),
|
||||
input_lwe_dimension.0 as u32,
|
||||
output_lwe_dimension.0 as u32,
|
||||
base_log.0 as u32,
|
||||
l_gadget.0 as u32,
|
||||
num_samples,
|
||||
);
|
||||
if TypeId::of::<KT>() == TypeId::of::<u32>() {
|
||||
cuda_keyswitch_lwe_ciphertext_vector_64_32(
|
||||
streams.ptr[0],
|
||||
streams.gpu_indexes[0].get(),
|
||||
lwe_array_out.as_mut_c_ptr(0),
|
||||
lwe_out_indexes.as_c_ptr(0),
|
||||
lwe_array_in.as_c_ptr(0),
|
||||
lwe_in_indexes.as_c_ptr(0),
|
||||
keyswitch_key.as_c_ptr(0),
|
||||
input_lwe_dimension.0 as u32,
|
||||
output_lwe_dimension.0 as u32,
|
||||
base_log.0 as u32,
|
||||
l_gadget.0 as u32,
|
||||
num_samples,
|
||||
);
|
||||
} else if TypeId::of::<KT>() == TypeId::of::<u64>() {
|
||||
cuda_keyswitch_lwe_ciphertext_vector_64_64(
|
||||
streams.ptr[0],
|
||||
streams.gpu_indexes[0].get(),
|
||||
lwe_array_out.as_mut_c_ptr(0),
|
||||
lwe_out_indexes.as_c_ptr(0),
|
||||
lwe_array_in.as_c_ptr(0),
|
||||
lwe_in_indexes.as_c_ptr(0),
|
||||
keyswitch_key.as_c_ptr(0),
|
||||
input_lwe_dimension.0 as u32,
|
||||
output_lwe_dimension.0 as u32,
|
||||
base_log.0 as u32,
|
||||
l_gadget.0 as u32,
|
||||
num_samples,
|
||||
);
|
||||
}
|
||||
}
|
||||
/// Convert keyswitch key
|
||||
///
|
||||
|
||||
@@ -487,7 +487,10 @@ impl CudaServerKey {
|
||||
}
|
||||
|
||||
pub fn gpu_indexes(&self) -> &[GpuIndex] {
|
||||
&self.key.key.key_switching_key.d_vec.gpu_indexes
|
||||
match &self.key.key.key_switching_key {
|
||||
CudaDynamicKeyswitchingKey::KeySwitch32(ksk_32) => ksk_32.d_vec.gpu_indexes.as_slice(),
|
||||
CudaDynamicKeyswitchingKey::Standard(std_key) => std_key.d_vec.gpu_indexes.as_slice(),
|
||||
}
|
||||
}
|
||||
pub(in crate::high_level_api) fn re_randomization_cpk_casting_key(
|
||||
&self,
|
||||
@@ -611,6 +614,8 @@ use crate::high_level_api::keys::inner::IntegerServerKeyConformanceParams;
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::integer::gpu::key_switching_key::CudaKeySwitchingKeyMaterial;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::integer::gpu::server_key::CudaDynamicKeyswitchingKey;
|
||||
|
||||
impl ParameterSetConformant for ServerKey {
|
||||
type ParameterSet = IntegerServerKeyConformanceParams;
|
||||
|
||||
@@ -10,7 +10,7 @@ use crate::integer::gpu::ciphertext::compressed_ciphertext_list::CudaExpandable;
|
||||
use crate::integer::gpu::ciphertext::info::{CudaBlockInfo, CudaRadixCiphertextInfo};
|
||||
use crate::integer::gpu::ciphertext::{CudaRadixCiphertext, CudaVec, KsType, LweDimension};
|
||||
use crate::integer::gpu::key_switching_key::CudaKeySwitchingKey;
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
|
||||
use crate::integer::gpu::{cuda_backend_expand, PBSType};
|
||||
use crate::shortint::ciphertext::CompactCiphertextList;
|
||||
use crate::shortint::parameters::{
|
||||
@@ -404,7 +404,12 @@ impl CudaFlattenedVecCompactCiphertextList {
|
||||
let d_input = &self.d_flattened_vec;
|
||||
let casting_key = key.key_switching_key_material;
|
||||
let sks = key.dest_server_key;
|
||||
let computing_ks_key = &key.dest_server_key.key_switching_key;
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) =
|
||||
&key.dest_server_key.key_switching_key
|
||||
else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let casting_key_type: KsType = casting_key.destination_key.into();
|
||||
|
||||
|
||||
@@ -365,13 +365,17 @@ pub(crate) unsafe fn cuda_backend_scalar_addition_assign<T: UnsignedInteger>(
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_unchecked_scalar_mul<T: UnsignedInteger, B: Numeric>(
|
||||
pub(crate) unsafe fn cuda_backend_unchecked_scalar_mul<
|
||||
T: UnsignedInteger,
|
||||
KST: UnsignedInteger,
|
||||
B: Numeric,
|
||||
>(
|
||||
streams: &CudaStreams,
|
||||
lwe_array: &mut CudaRadixCiphertext,
|
||||
decomposed_scalar: &[T],
|
||||
has_at_least_one_set: &[T],
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<u64>,
|
||||
keyswitch_key: &CudaVec<KST>,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
@@ -1736,13 +1740,17 @@ pub(crate) fn cuda_backend_get_bitop_size_on_gpu(
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_unchecked_scalar_bitop_assign<T: UnsignedInteger, B: Numeric>(
|
||||
pub(crate) unsafe fn cuda_backend_unchecked_scalar_bitop_assign<
|
||||
T: UnsignedInteger,
|
||||
KST: UnsignedInteger,
|
||||
B: Numeric,
|
||||
>(
|
||||
streams: &CudaStreams,
|
||||
radix_lwe: &mut CudaRadixCiphertext,
|
||||
clear_blocks: &CudaVec<T>,
|
||||
h_clear_blocks: &[T],
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
keyswitch_key: &CudaVec<KST>,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
@@ -2102,14 +2110,18 @@ pub(crate) fn cuda_backend_get_comparison_size_on_gpu(
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_unchecked_scalar_comparison<T: UnsignedInteger, B: Numeric>(
|
||||
pub(crate) unsafe fn cuda_backend_unchecked_scalar_comparison<
|
||||
T: UnsignedInteger,
|
||||
KST: UnsignedInteger,
|
||||
B: Numeric,
|
||||
>(
|
||||
streams: &CudaStreams,
|
||||
radix_lwe_out: &mut CudaRadixCiphertext,
|
||||
radix_lwe_in: &CudaRadixCiphertext,
|
||||
scalar_blocks: &CudaVec<T>,
|
||||
h_scalar_blocks: &[T],
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
keyswitch_key: &CudaVec<KST>,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
@@ -3032,7 +3044,11 @@ pub(crate) unsafe fn cuda_backend_grouped_oprf<B: Numeric>(
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_grouped_oprf_custom_range<T: UnsignedInteger, B: Numeric>(
|
||||
pub(crate) unsafe fn cuda_backend_grouped_oprf_custom_range<
|
||||
T: UnsignedInteger,
|
||||
B: Numeric,
|
||||
KST: Numeric,
|
||||
>(
|
||||
streams: &CudaStreams,
|
||||
radix_lwe_out: &mut CudaRadixCiphertext,
|
||||
num_blocks_intermediate: u32,
|
||||
@@ -3041,7 +3057,7 @@ pub(crate) unsafe fn cuda_backend_grouped_oprf_custom_range<T: UnsignedInteger,
|
||||
has_at_least_one_set: &[T],
|
||||
shift: u32,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
key_switching_key: &CudaVec<u64>,
|
||||
key_switching_key: &CudaVec<KST>,
|
||||
lwe_dimension: LweDimension,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
@@ -5916,7 +5932,11 @@ pub(crate) unsafe fn cuda_backend_unchecked_partial_sum_ciphertexts_assign<
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_apply_univariate_lut<T: UnsignedInteger, B: Numeric>(
|
||||
pub(crate) unsafe fn cuda_backend_apply_univariate_lut<
|
||||
T: UnsignedInteger,
|
||||
KST: UnsignedInteger,
|
||||
B: Numeric,
|
||||
>(
|
||||
streams: &CudaStreams,
|
||||
output: &mut CudaSliceMut<T>,
|
||||
output_degrees: &mut Vec<u64>,
|
||||
@@ -5925,7 +5945,7 @@ pub(crate) unsafe fn cuda_backend_apply_univariate_lut<T: UnsignedInteger, B: Nu
|
||||
input_lut: &[T],
|
||||
lut_degree: u64,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
keyswitch_key: &CudaVec<KST>,
|
||||
lwe_dimension: LweDimension,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
@@ -6023,7 +6043,11 @@ pub(crate) unsafe fn cuda_backend_apply_univariate_lut<T: UnsignedInteger, B: Nu
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_apply_many_univariate_lut<T: UnsignedInteger, B: Numeric>(
|
||||
pub(crate) unsafe fn cuda_backend_apply_many_univariate_lut<
|
||||
T: UnsignedInteger,
|
||||
KST: UnsignedInteger,
|
||||
B: Numeric,
|
||||
>(
|
||||
streams: &CudaStreams,
|
||||
output: &mut CudaSliceMut<T>,
|
||||
output_degrees: &mut Vec<u64>,
|
||||
@@ -6032,7 +6056,7 @@ pub(crate) unsafe fn cuda_backend_apply_many_univariate_lut<T: UnsignedInteger,
|
||||
input_lut: &[T],
|
||||
lut_degree: u64,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
keyswitch_key: &CudaVec<KST>,
|
||||
lwe_dimension: LweDimension,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
@@ -7250,14 +7274,18 @@ pub(crate) unsafe fn cuda_backend_extend_radix_with_trivial_zero_blocks_msb(
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_noise_squashing<T: UnsignedInteger, B: Numeric>(
|
||||
pub(crate) unsafe fn cuda_backend_noise_squashing<
|
||||
T: UnsignedInteger,
|
||||
KST: UnsignedInteger,
|
||||
B: Numeric,
|
||||
>(
|
||||
streams: &CudaStreams,
|
||||
output: &mut CudaSliceMut<T>,
|
||||
output_degrees: &mut Vec<u64>,
|
||||
output_noise_levels: &mut Vec<u64>,
|
||||
input: &CudaSlice<u64>,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<u64>,
|
||||
keyswitch_key: &CudaVec<KST>,
|
||||
lwe_dimension: LweDimension,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
@@ -7369,12 +7397,12 @@ pub(crate) unsafe fn cuda_backend_noise_squashing<T: UnsignedInteger, B: Numeric
|
||||
/// that were inside that vector of compact list. Handling the input this way removes the need
|
||||
/// to process multiple compact lists separately, simplifying GPU-based operations. The variable
|
||||
/// name `lwe_flattened_compact_array_in` makes this intent explicit.
|
||||
pub(crate) unsafe fn cuda_backend_expand<T: UnsignedInteger, B: Numeric>(
|
||||
pub(crate) unsafe fn cuda_backend_expand<T: UnsignedInteger, KST: UnsignedInteger, B: Numeric>(
|
||||
streams: &CudaStreams,
|
||||
lwe_array_out: &mut CudaLweCiphertextList<T>,
|
||||
lwe_flattened_compact_array_in: &CudaVec<T>,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
computing_ks_key: &CudaVec<T>,
|
||||
computing_ks_key: &CudaVec<KST>,
|
||||
casting_key: &CudaVec<T>,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
|
||||
@@ -36,14 +36,18 @@ impl<Scalar: UnsignedInteger> CudaBootstrappingKey<Scalar> {
|
||||
}
|
||||
}
|
||||
|
||||
pub enum CudaDynamicKeyswitchingKey {
|
||||
Standard(CudaLweKeyswitchKey<u64>),
|
||||
KeySwitch32(CudaLweKeyswitchKey<u32>),
|
||||
}
|
||||
/// A structure containing the server public key.
|
||||
///
|
||||
/// The server key is generated by the client and is meant to be published: the client
|
||||
/// sends it to the server so it can compute homomorphic circuits.
|
||||
// #[derive(PartialEq, Serialize, Deserialize)]
|
||||
pub struct CudaServerKey {
|
||||
pub key_switching_key: CudaLweKeyswitchKey<u64>,
|
||||
pub bootstrapping_key: CudaBootstrappingKey<u64>,
|
||||
pub key_switching_key: CudaDynamicKeyswitchingKey,
|
||||
pub bootstrapping_key: CudaBootstrappingKey<u64>, // the GGSW of the BSK
|
||||
// Size of the message buffer
|
||||
pub message_modulus: MessageModulus,
|
||||
// Size of the carry buffer
|
||||
@@ -180,7 +184,7 @@ impl CudaServerKey {
|
||||
|
||||
// Pack the keys in the server key set:
|
||||
Self {
|
||||
key_switching_key: d_key_switching_key,
|
||||
key_switching_key: CudaDynamicKeyswitchingKey::Standard(d_key_switching_key),
|
||||
bootstrapping_key: d_bootstrapping_key,
|
||||
message_modulus: std_cks.parameters.message_modulus(),
|
||||
carry_modulus: std_cks.parameters.carry_modulus(),
|
||||
@@ -239,57 +243,108 @@ impl CudaServerKey {
|
||||
max_noise_level,
|
||||
} = cpu_key.key.clone();
|
||||
|
||||
// Generate a regular keyset and convert to the GPU
|
||||
let CompressedAtomicPatternServerKey::Standard(std_key) = compressed_ap_server_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
let ciphertext_modulus = compressed_ap_server_key.ciphertext_modulus();
|
||||
match compressed_ap_server_key {
|
||||
CompressedAtomicPatternServerKey::Standard(std_key) => {
|
||||
let (key_switching_key, bootstrapping_key, pbs_order) = std_key.into_raw_parts();
|
||||
|
||||
let ciphertext_modulus = std_key.ciphertext_modulus();
|
||||
let (key_switching_key, bootstrapping_key, pbs_order) = std_key.into_raw_parts();
|
||||
let h_key_switching_key = key_switching_key.par_decompress_into_lwe_keyswitch_key();
|
||||
let key_switching_key =
|
||||
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&h_key_switching_key, streams);
|
||||
let bootstrapping_key = match bootstrapping_key {
|
||||
crate::shortint::server_key::compressed::ShortintCompressedBootstrappingKey::Classic{ bsk: h_bootstrap_key, modulus_switch_noise_reduction_key, } => {
|
||||
|
||||
let h_key_switching_key = key_switching_key.par_decompress_into_lwe_keyswitch_key();
|
||||
let key_switching_key =
|
||||
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&h_key_switching_key, streams);
|
||||
let bootstrapping_key = match bootstrapping_key {
|
||||
crate::shortint::server_key::compressed::ShortintCompressedBootstrappingKey::Classic{ bsk: h_bootstrap_key, modulus_switch_noise_reduction_key, } => {
|
||||
let modulus_switch_noise_reduction_configuration = match modulus_switch_noise_reduction_key {
|
||||
CompressedModulusSwitchConfiguration::Standard => None,
|
||||
CompressedModulusSwitchConfiguration::DriftTechniqueNoiseReduction(_modulus_switch_noise_reduction_key) => panic!("Drift noise reduction is not supported on GPU"),
|
||||
CompressedModulusSwitchConfiguration::CenteredMeanNoiseReduction => Some(CudaModulusSwitchNoiseReductionConfiguration::Centered),
|
||||
};
|
||||
|
||||
let modulus_switch_noise_reduction_configuration = match modulus_switch_noise_reduction_key {
|
||||
CompressedModulusSwitchConfiguration::Standard => None,
|
||||
CompressedModulusSwitchConfiguration::DriftTechniqueNoiseReduction(_modulus_switch_noise_reduction_key) => panic!("Drift noise reduction is not supported on GPU"),
|
||||
CompressedModulusSwitchConfiguration::CenteredMeanNoiseReduction => Some(CudaModulusSwitchNoiseReductionConfiguration::Centered),
|
||||
let standard_bootstrapping_key = h_bootstrap_key.par_decompress_into_lwe_bootstrap_key();
|
||||
|
||||
let d_bootstrap_key =
|
||||
CudaLweBootstrapKey::from_lwe_bootstrap_key(&standard_bootstrapping_key, modulus_switch_noise_reduction_configuration, streams);
|
||||
|
||||
CudaBootstrappingKey::Classic(d_bootstrap_key)
|
||||
}
|
||||
crate::shortint::server_key::compressed::ShortintCompressedBootstrappingKey::MultiBit {
|
||||
seeded_bsk: bootstrapping_key,
|
||||
deterministic_execution: _,
|
||||
} => {
|
||||
let standard_bootstrapping_key =
|
||||
bootstrapping_key.par_decompress_into_lwe_multi_bit_bootstrap_key();
|
||||
|
||||
let d_bootstrap_key =
|
||||
CudaLweMultiBitBootstrapKey::from_lwe_multi_bit_bootstrap_key(
|
||||
&standard_bootstrapping_key, streams);
|
||||
|
||||
CudaBootstrappingKey::MultiBit(d_bootstrap_key)
|
||||
}
|
||||
};
|
||||
|
||||
let standard_bootstrapping_key = h_bootstrap_key.par_decompress_into_lwe_bootstrap_key();
|
||||
Self {
|
||||
key_switching_key: CudaDynamicKeyswitchingKey::Standard(key_switching_key),
|
||||
bootstrapping_key,
|
||||
message_modulus,
|
||||
carry_modulus,
|
||||
max_degree,
|
||||
max_noise_level,
|
||||
ciphertext_modulus,
|
||||
pbs_order,
|
||||
}
|
||||
}
|
||||
CompressedAtomicPatternServerKey::KeySwitch32(ks32_key) => {
|
||||
let key_switching_key = ks32_key.key_switching_key();
|
||||
let bootstrapping_key = ks32_key.bootstrapping_key();
|
||||
|
||||
let d_bootstrap_key =
|
||||
let h_key_switching_key = key_switching_key
|
||||
.as_view()
|
||||
.par_decompress_into_lwe_keyswitch_key();
|
||||
let key_switching_key =
|
||||
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&h_key_switching_key, streams);
|
||||
|
||||
let bootstrapping_key = match bootstrapping_key {
|
||||
crate::shortint::server_key::compressed::ShortintCompressedBootstrappingKey::Classic{ bsk: h_bootstrap_key, modulus_switch_noise_reduction_key, } => {
|
||||
|
||||
let modulus_switch_noise_reduction_configuration = match modulus_switch_noise_reduction_key {
|
||||
CompressedModulusSwitchConfiguration::Standard => None,
|
||||
CompressedModulusSwitchConfiguration::DriftTechniqueNoiseReduction(_modulus_switch_noise_reduction_key) => panic!("Drift noise reduction is not supported on GPU"),
|
||||
CompressedModulusSwitchConfiguration::CenteredMeanNoiseReduction => Some(CudaModulusSwitchNoiseReductionConfiguration::Centered),
|
||||
};
|
||||
|
||||
let standard_bootstrapping_key = h_bootstrap_key.as_view().par_decompress_into_lwe_bootstrap_key();
|
||||
|
||||
let d_bootstrap_key =
|
||||
CudaLweBootstrapKey::from_lwe_bootstrap_key(&standard_bootstrapping_key, modulus_switch_noise_reduction_configuration, streams);
|
||||
|
||||
CudaBootstrappingKey::Classic(d_bootstrap_key)
|
||||
}
|
||||
crate::shortint::server_key::compressed::ShortintCompressedBootstrappingKey::MultiBit {
|
||||
seeded_bsk: bootstrapping_key,
|
||||
deterministic_execution: _,
|
||||
} => {
|
||||
let standard_bootstrapping_key =
|
||||
bootstrapping_key.par_decompress_into_lwe_multi_bit_bootstrap_key();
|
||||
CudaBootstrappingKey::Classic(d_bootstrap_key)
|
||||
}
|
||||
crate::shortint::server_key::compressed::ShortintCompressedBootstrappingKey::MultiBit {
|
||||
seeded_bsk: bootstrapping_key,
|
||||
deterministic_execution: _,
|
||||
} => {
|
||||
let standard_bootstrapping_key =
|
||||
bootstrapping_key.as_view().par_decompress_into_lwe_multi_bit_bootstrap_key();
|
||||
|
||||
let d_bootstrap_key =
|
||||
CudaLweMultiBitBootstrapKey::from_lwe_multi_bit_bootstrap_key(
|
||||
let d_bootstrap_key =
|
||||
CudaLweMultiBitBootstrapKey::from_lwe_multi_bit_bootstrap_key(
|
||||
&standard_bootstrapping_key, streams);
|
||||
|
||||
CudaBootstrappingKey::MultiBit(d_bootstrap_key)
|
||||
}
|
||||
};
|
||||
CudaBootstrappingKey::MultiBit(d_bootstrap_key)
|
||||
}
|
||||
};
|
||||
|
||||
Self {
|
||||
key_switching_key,
|
||||
bootstrapping_key,
|
||||
message_modulus,
|
||||
carry_modulus,
|
||||
max_degree,
|
||||
max_noise_level,
|
||||
ciphertext_modulus,
|
||||
pbs_order,
|
||||
Self {
|
||||
key_switching_key: CudaDynamicKeyswitchingKey::KeySwitch32(key_switching_key),
|
||||
bootstrapping_key,
|
||||
message_modulus,
|
||||
carry_modulus,
|
||||
max_degree,
|
||||
max_noise_level,
|
||||
ciphertext_modulus,
|
||||
pbs_order: PBSOrder::KeyswitchBootstrap,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::LweBskGroupingFactor;
|
||||
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::server_key::{
|
||||
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
|
||||
};
|
||||
use crate::integer::gpu::{cuda_backend_unchecked_signed_abs_assign, PBSType};
|
||||
|
||||
impl CudaServerKey {
|
||||
@@ -11,6 +13,10 @@ impl CudaServerKey {
|
||||
{
|
||||
let num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0 as u32;
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -18,19 +24,15 @@ impl CudaServerKey {
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
num_blocks,
|
||||
@@ -44,19 +46,15 @@ impl CudaServerKey {
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
num_blocks,
|
||||
|
||||
@@ -5,7 +5,9 @@ use crate::integer::gpu::ciphertext::{
|
||||
CudaIntegerRadixCiphertext, CudaRadixCiphertext, CudaSignedRadixCiphertext,
|
||||
CudaUnsignedRadixCiphertext,
|
||||
};
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::server_key::{
|
||||
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
|
||||
};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_add_and_propagate_single_carry_assign,
|
||||
cuda_backend_get_add_and_propagate_single_carry_assign_size_on_gpu,
|
||||
@@ -127,6 +129,10 @@ impl CudaServerKey {
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count().0,
|
||||
ct_right.as_ref().d_blocks.lwe_ciphertext_count().0
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu(
|
||||
@@ -134,8 +140,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -151,8 +157,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -181,8 +187,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
@@ -200,8 +206,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
@@ -331,6 +337,9 @@ impl CudaServerKey {
|
||||
let radix_count_in_vec = ciphertexts.len();
|
||||
|
||||
let mut terms = CudaRadixCiphertext::from_radix_ciphertext_vec(ciphertexts, streams);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -341,16 +350,14 @@ impl CudaServerKey {
|
||||
&mut terms,
|
||||
reduce_degrees_for_single_carry_propagation,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
num_blocks.0 as u32,
|
||||
@@ -367,16 +374,14 @@ impl CudaServerKey {
|
||||
&mut terms,
|
||||
reduce_degrees_for_single_carry_propagation,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
num_blocks.0 as u32,
|
||||
@@ -694,6 +699,9 @@ impl CudaServerKey {
|
||||
let aux_block: T = self.create_trivial_zero_radix(1, streams);
|
||||
let in_carry: &CudaRadixCiphertext =
|
||||
input_carry.map_or_else(|| aux_block.as_ref(), |block| block.0.as_ref());
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -705,12 +713,12 @@ impl CudaServerKey {
|
||||
carry_out.as_mut(),
|
||||
in_carry,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
@@ -731,12 +739,12 @@ impl CudaServerKey {
|
||||
carry_out.as_mut(),
|
||||
in_carry,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
|
||||
@@ -2,7 +2,9 @@ use crate::core_crypto::gpu::{
|
||||
check_valid_cuda_malloc, check_valid_cuda_malloc_assert_oom, CudaStreams,
|
||||
};
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::server_key::{
|
||||
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
|
||||
};
|
||||
|
||||
use crate::core_crypto::prelude::LweBskGroupingFactor;
|
||||
use crate::integer::gpu::{
|
||||
@@ -271,6 +273,10 @@ impl CudaServerKey {
|
||||
result.as_ref().d_blocks.lwe_ciphertext_count().0
|
||||
);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -283,14 +289,14 @@ impl CudaServerKey {
|
||||
num_aes_inputs as u32,
|
||||
sbox_parallelism as u32,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -308,14 +314,14 @@ impl CudaServerKey {
|
||||
num_aes_inputs as u32,
|
||||
sbox_parallelism as u32,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
@@ -334,6 +340,10 @@ impl CudaServerKey {
|
||||
sbox_parallelism: usize,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_aes_ctr_encrypt_size_on_gpu(
|
||||
streams,
|
||||
@@ -344,8 +354,8 @@ impl CudaServerKey {
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -362,8 +372,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
@@ -392,6 +402,10 @@ impl CudaServerKey {
|
||||
key.as_ref().d_blocks.lwe_ciphertext_count().0
|
||||
);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -400,14 +414,14 @@ impl CudaServerKey {
|
||||
expanded_keys.as_mut(),
|
||||
key.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -421,14 +435,14 @@ impl CudaServerKey {
|
||||
expanded_keys.as_mut(),
|
||||
key.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
@@ -442,6 +456,10 @@ impl CudaServerKey {
|
||||
}
|
||||
|
||||
pub fn get_key_expansion_size_on_gpu(&self, streams: &CudaStreams) -> u64 {
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_aes_key_expansion_size_on_gpu(
|
||||
streams,
|
||||
@@ -450,8 +468,8 @@ impl CudaServerKey {
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -466,8 +484,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
|
||||
@@ -2,7 +2,9 @@ use crate::core_crypto::gpu::{
|
||||
check_valid_cuda_malloc, check_valid_cuda_malloc_assert_oom, CudaStreams,
|
||||
};
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::server_key::{
|
||||
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
|
||||
};
|
||||
|
||||
use crate::core_crypto::prelude::LweBskGroupingFactor;
|
||||
use crate::integer::gpu::{
|
||||
@@ -197,6 +199,10 @@ impl CudaServerKey {
|
||||
result.as_ref().d_blocks.lwe_ciphertext_count().0
|
||||
);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -209,14 +215,14 @@ impl CudaServerKey {
|
||||
num_aes_inputs as u32,
|
||||
sbox_parallelism as u32,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -234,14 +240,14 @@ impl CudaServerKey {
|
||||
num_aes_inputs as u32,
|
||||
sbox_parallelism as u32,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
@@ -274,6 +280,10 @@ impl CudaServerKey {
|
||||
key.as_ref().d_blocks.lwe_ciphertext_count().0
|
||||
);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -282,14 +292,14 @@ impl CudaServerKey {
|
||||
expanded_keys.as_mut(),
|
||||
key.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -303,14 +313,14 @@ impl CudaServerKey {
|
||||
expanded_keys.as_mut(),
|
||||
key.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
@@ -324,6 +334,10 @@ impl CudaServerKey {
|
||||
}
|
||||
|
||||
pub fn get_key_expansion_256_size_on_gpu(&self, streams: &CudaStreams) -> u64 {
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_get_aes_key_expansion_256_size_on_gpu(
|
||||
@@ -333,8 +347,8 @@ impl CudaServerKey {
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -350,8 +364,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::LweBskGroupingFactor;
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_boolean_bitnot_assign, cuda_backend_boolean_bitop_assign,
|
||||
cuda_backend_get_bitop_size_on_gpu, cuda_backend_get_boolean_bitnot_size_on_gpu,
|
||||
@@ -324,6 +324,10 @@ impl CudaServerKey {
|
||||
ct_right.0.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -332,19 +336,15 @@ impl CudaServerKey {
|
||||
ct_left.0.as_mut(),
|
||||
ct_right.0.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
@@ -361,19 +361,15 @@ impl CudaServerKey {
|
||||
ct_left.0.as_mut(),
|
||||
ct_right.0.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
@@ -412,6 +408,10 @@ impl CudaServerKey {
|
||||
is_unchecked: bool,
|
||||
streams: &CudaStreams,
|
||||
) {
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -420,15 +420,15 @@ impl CudaServerKey {
|
||||
&mut ct.0.ciphertext as &mut CudaRadixCiphertext,
|
||||
is_unchecked,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
d_bsk.output_lwe_dimension(),
|
||||
d_bsk.input_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
PBSType::Classical,
|
||||
@@ -442,15 +442,15 @@ impl CudaServerKey {
|
||||
&mut ct.0.ciphertext as &mut CudaRadixCiphertext,
|
||||
is_unchecked,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
d_multibit_bsk.output_lwe_dimension(),
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
PBSType::MultiBit,
|
||||
@@ -534,6 +534,10 @@ impl CudaServerKey {
|
||||
|
||||
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -542,19 +546,15 @@ impl CudaServerKey {
|
||||
ct_left.as_mut(),
|
||||
ct_right.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
@@ -570,19 +570,15 @@ impl CudaServerKey {
|
||||
ct_left.as_mut(),
|
||||
ct_right.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
@@ -611,6 +607,10 @@ impl CudaServerKey {
|
||||
ct_left.0.as_ref().d_blocks.lwe_ciphertext_count(),
|
||||
ct_right.0.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let boolean_bitop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_boolean_bitop_size_on_gpu(
|
||||
streams,
|
||||
@@ -618,14 +618,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
@@ -642,14 +638,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
@@ -679,6 +671,10 @@ impl CudaServerKey {
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
|
||||
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu(
|
||||
@@ -686,8 +682,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -703,8 +699,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -734,14 +730,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
@@ -756,14 +748,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
@@ -1287,6 +1275,10 @@ impl CudaServerKey {
|
||||
_ct: &CudaBooleanBlock,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let boolean_bitnot_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_boolean_bitnot_size_on_gpu(
|
||||
streams,
|
||||
@@ -1294,14 +1286,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
false,
|
||||
@@ -1317,14 +1305,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
false,
|
||||
@@ -1344,6 +1328,10 @@ impl CudaServerKey {
|
||||
ct: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = if ct.block_carries_are_empty() {
|
||||
0
|
||||
} else {
|
||||
@@ -1354,8 +1342,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -1371,8 +1359,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::LweBskGroupingFactor;
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_get_cmux_size_on_gpu, cuda_backend_get_full_propagate_assign_size_on_gpu,
|
||||
cuda_backend_unchecked_cmux, CudaServerKey, PBSType,
|
||||
@@ -20,6 +20,10 @@ impl CudaServerKey {
|
||||
let mut result: T = self
|
||||
.create_trivial_zero_radix(true_ct.as_ref().d_blocks.lwe_ciphertext_count().0, stream);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -30,19 +34,15 @@ impl CudaServerKey {
|
||||
true_ct.as_ref(),
|
||||
false_ct.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -59,19 +59,15 @@ impl CudaServerKey {
|
||||
true_ct.as_ref(),
|
||||
false_ct.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -129,6 +125,10 @@ impl CudaServerKey {
|
||||
true_ct.as_ref().d_blocks.lwe_ciphertext_count(),
|
||||
false_ct.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu(
|
||||
@@ -136,8 +136,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -153,8 +153,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -184,14 +184,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -205,14 +201,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::core_crypto::prelude::{LweBskGroupingFactor, LweCiphertextCount};
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
use crate::integer::gpu::ciphertext::info::CudaRadixCiphertextInfo;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_get_comparison_size_on_gpu, cuda_backend_get_full_propagate_assign_size_on_gpu,
|
||||
cuda_backend_unchecked_comparison, ComparisonType, CudaServerKey, PBSType,
|
||||
@@ -45,6 +45,10 @@ impl CudaServerKey {
|
||||
let mut result =
|
||||
CudaBooleanBlock::from_cuda_radix_ciphertext(CudaRadixCiphertext::new(block, ct_info));
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -54,19 +58,15 @@ impl CudaServerKey {
|
||||
ct_left.as_ref(),
|
||||
ct_right.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
@@ -83,19 +83,15 @@ impl CudaServerKey {
|
||||
ct_left.as_ref(),
|
||||
ct_right.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
@@ -226,6 +222,10 @@ impl CudaServerKey {
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
|
||||
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu(
|
||||
@@ -233,8 +233,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -250,8 +250,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -273,6 +273,9 @@ impl CudaServerKey {
|
||||
};
|
||||
|
||||
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let comparison_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_comparison_size_on_gpu(
|
||||
@@ -281,14 +284,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -305,14 +304,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -822,6 +817,10 @@ impl CudaServerKey {
|
||||
|
||||
let mut result = ct_left.duplicate(streams);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -831,19 +830,15 @@ impl CudaServerKey {
|
||||
ct_left.as_ref(),
|
||||
ct_right.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
ComparisonType::MAX,
|
||||
@@ -860,19 +855,15 @@ impl CudaServerKey {
|
||||
ct_left.as_ref(),
|
||||
ct_right.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
ComparisonType::MAX,
|
||||
@@ -902,6 +893,10 @@ impl CudaServerKey {
|
||||
|
||||
let mut result = ct_left.duplicate(streams);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -911,19 +906,15 @@ impl CudaServerKey {
|
||||
ct_left.as_ref(),
|
||||
ct_right.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
ComparisonType::MIN,
|
||||
@@ -940,19 +931,15 @@ impl CudaServerKey {
|
||||
ct_left.as_ref(),
|
||||
ct_right.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
ComparisonType::MIN,
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::LweBskGroupingFactor;
|
||||
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::server_key::{
|
||||
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
|
||||
};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_get_div_rem_size_on_gpu, cuda_backend_get_full_propagate_assign_size_on_gpu,
|
||||
cuda_backend_unchecked_div_rem_assign, PBSType,
|
||||
@@ -18,6 +20,10 @@ impl CudaServerKey {
|
||||
) where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
// TODO add asserts from `unchecked_div_rem_parallelized`
|
||||
let num_blocks = divisor.as_ref().d_blocks.lwe_ciphertext_count().0 as u32;
|
||||
unsafe {
|
||||
@@ -31,19 +37,15 @@ impl CudaServerKey {
|
||||
divisor.as_ref(),
|
||||
T::IS_SIGNED,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
num_blocks,
|
||||
@@ -61,19 +63,15 @@ impl CudaServerKey {
|
||||
divisor.as_ref(),
|
||||
T::IS_SIGNED,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
num_blocks,
|
||||
@@ -227,6 +225,10 @@ impl CudaServerKey {
|
||||
numerator.as_ref().d_blocks.lwe_ciphertext_count(),
|
||||
divisor.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu(
|
||||
@@ -234,8 +236,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -251,8 +253,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -283,14 +285,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -305,14 +303,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
|
||||
@@ -4,7 +4,9 @@ use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
use crate::integer::gpu::ciphertext::{
|
||||
CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext,
|
||||
};
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::server_key::{
|
||||
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
|
||||
};
|
||||
use crate::integer::gpu::{cuda_backend_count_of_consecutive_bits, cuda_backend_ilog2, PBSType};
|
||||
use crate::integer::server_key::radix_parallel::ilog2::{BitValue, Direction};
|
||||
|
||||
@@ -34,6 +36,10 @@ impl CudaServerKey {
|
||||
let mut result: CudaUnsignedRadixCiphertext =
|
||||
self.create_trivial_zero_radix(counter_num_blocks, streams);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -42,12 +48,12 @@ impl CudaServerKey {
|
||||
result.as_mut(),
|
||||
ct.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -65,12 +71,12 @@ impl CudaServerKey {
|
||||
result.as_mut(),
|
||||
ct.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -190,6 +196,10 @@ impl CudaServerKey {
|
||||
let mut result: CudaUnsignedRadixCiphertext =
|
||||
self.create_trivial_zero_radix(counter_num_blocks, streams);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -201,12 +211,12 @@ impl CudaServerKey {
|
||||
trivial_ct_2.as_ref(),
|
||||
trivial_ct_m_minus_1_block.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -228,12 +238,12 @@ impl CudaServerKey {
|
||||
trivial_ct_2.as_ref(),
|
||||
trivial_ct_m_minus_1_block.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
d_multibit_bsk.grouping_factor,
|
||||
|
||||
@@ -14,7 +14,7 @@ use crate::integer::gpu::ciphertext::{
|
||||
CudaUnsignedRadixCiphertext,
|
||||
};
|
||||
use crate::integer::gpu::noise_squashing::keys::CudaNoiseSquashingKey;
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_apply_many_univariate_lut, cuda_backend_apply_univariate_lut,
|
||||
cuda_backend_cast_to_signed, cuda_backend_cast_to_unsigned,
|
||||
@@ -182,9 +182,13 @@ impl CudaServerKey {
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
Scalar: DecomposableInto<u64>,
|
||||
{
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let lwe_size = match self.pbs_order {
|
||||
PBSOrder::KeyswitchBootstrap => self.key_switching_key.input_key_lwe_size(),
|
||||
PBSOrder::BootstrapKeyswitch => self.key_switching_key.output_key_lwe_size(),
|
||||
PBSOrder::KeyswitchBootstrap => computing_ks_key.input_key_lwe_size(),
|
||||
PBSOrder::BootstrapKeyswitch => computing_ks_key.output_key_lwe_size(),
|
||||
};
|
||||
|
||||
let decomposer =
|
||||
@@ -235,6 +239,10 @@ impl CudaServerKey {
|
||||
let in_carry: &CudaRadixCiphertext =
|
||||
input_carry.map_or_else(|| aux_block.as_ref(), |block| block.0.as_ref());
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -244,12 +252,12 @@ impl CudaServerKey {
|
||||
carry_out.as_mut(),
|
||||
in_carry,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
@@ -269,12 +277,12 @@ impl CudaServerKey {
|
||||
carry_out.as_mut(),
|
||||
in_carry,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
@@ -299,6 +307,10 @@ impl CudaServerKey {
|
||||
) {
|
||||
let ciphertext = ct.as_mut();
|
||||
let num_blocks = ciphertext.d_blocks.lwe_ciphertext_count().0 as u32;
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -306,12 +318,12 @@ impl CudaServerKey {
|
||||
streams,
|
||||
ciphertext,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
@@ -327,12 +339,12 @@ impl CudaServerKey {
|
||||
streams,
|
||||
ciphertext,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
@@ -689,6 +701,10 @@ impl CudaServerKey {
|
||||
let mut output_degrees = vec![0_u64; num_output_blocks];
|
||||
let mut output_noise_levels = vec![0_u64; num_output_blocks];
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let num_ct_blocks = block_range.len() as u32;
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -702,14 +718,12 @@ impl CudaServerKey {
|
||||
lut.acc.as_ref(),
|
||||
lut.degree.0,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
&computing_ks_key.d_vec,
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
num_ct_blocks,
|
||||
@@ -730,14 +744,12 @@ impl CudaServerKey {
|
||||
lut.acc.as_ref(),
|
||||
lut.degree.0,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
&computing_ks_key.d_vec,
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
num_ct_blocks,
|
||||
@@ -860,6 +872,9 @@ impl CudaServerKey {
|
||||
.unwrap();
|
||||
let mut output_degrees = vec![0_u64; num_ct_blocks * function_count];
|
||||
let mut output_noise_levels = vec![0_u64; num_ct_blocks * function_count];
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -873,14 +888,12 @@ impl CudaServerKey {
|
||||
lut.acc.as_ref(),
|
||||
lut.input_max_degree.0,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
&computing_ks_key.d_vec,
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
num_ct_blocks as u32,
|
||||
@@ -903,14 +916,12 @@ impl CudaServerKey {
|
||||
lut.acc.as_ref(),
|
||||
lut.input_max_degree.0,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
&computing_ks_key.d_vec,
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
num_ct_blocks as u32,
|
||||
@@ -1001,6 +1012,9 @@ impl CudaServerKey {
|
||||
self.create_trivial_zero_radix(target_num_blocks, streams);
|
||||
|
||||
let requires_full_propagate = !source.block_carries_are_empty();
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -1013,17 +1027,13 @@ impl CudaServerKey {
|
||||
requires_full_propagate,
|
||||
target_num_blocks as u32,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -1040,17 +1050,13 @@ impl CudaServerKey {
|
||||
requires_full_propagate,
|
||||
target_num_blocks as u32,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
@@ -1114,6 +1120,10 @@ impl CudaServerKey {
|
||||
let mut output_ct: CudaSignedRadixCiphertext =
|
||||
self.create_trivial_zero_radix(target_num_blocks, streams);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -1123,16 +1133,14 @@ impl CudaServerKey {
|
||||
source.as_ref(),
|
||||
T::IS_SIGNED,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -1147,16 +1155,14 @@ impl CudaServerKey {
|
||||
source.as_ref(),
|
||||
T::IS_SIGNED,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
@@ -1202,6 +1208,10 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
),
|
||||
};
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &d_bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(bsk) => {
|
||||
@@ -1212,16 +1222,14 @@ impl CudaServerKey {
|
||||
&mut output_noise_levels,
|
||||
&input_slice,
|
||||
&bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
&computing_ks_key.d_vec,
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
bsk.glwe_dimension,
|
||||
bsk.polynomial_size,
|
||||
input_glwe_dimension,
|
||||
input_polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
bsk.decomp_level_count,
|
||||
bsk.decomp_base_log,
|
||||
num_output_blocks as u32,
|
||||
@@ -1241,16 +1249,14 @@ impl CudaServerKey {
|
||||
&mut output_noise_levels,
|
||||
&input_slice,
|
||||
&mb_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
&computing_ks_key.d_vec,
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
mb_bsk.glwe_dimension,
|
||||
mb_bsk.polynomial_size,
|
||||
input_glwe_dimension,
|
||||
input_polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
mb_bsk.decomp_level_count,
|
||||
mb_bsk.decomp_base_log,
|
||||
num_output_blocks as u32,
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::LweBskGroupingFactor;
|
||||
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::server_key::{
|
||||
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
|
||||
};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu, cuda_backend_get_mul_size_on_gpu,
|
||||
cuda_backend_unchecked_mul_assign, PBSType,
|
||||
@@ -74,6 +76,10 @@ impl CudaServerKey {
|
||||
|
||||
let is_boolean_left = ct_left.holds_boolean_value();
|
||||
let is_boolean_right = ct_right.holds_boolean_value();
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -84,7 +90,7 @@ impl CudaServerKey {
|
||||
ct_right.as_ref(),
|
||||
is_boolean_right,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension(),
|
||||
@@ -92,8 +98,8 @@ impl CudaServerKey {
|
||||
d_bsk.polynomial_size(),
|
||||
d_bsk.decomp_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
num_blocks,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -108,7 +114,7 @@ impl CudaServerKey {
|
||||
ct_right.as_ref(),
|
||||
is_boolean_right,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
@@ -116,8 +122,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
num_blocks,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
@@ -233,6 +239,10 @@ impl CudaServerKey {
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
|
||||
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu(
|
||||
@@ -240,8 +250,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -257,8 +267,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -291,14 +301,12 @@ impl CudaServerKey {
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.decomp_base_log,
|
||||
d_bsk.decomp_level_count,
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -311,14 +319,12 @@ impl CudaServerKey {
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
|
||||
@@ -3,7 +3,9 @@ use crate::integer::gpu::ciphertext::{
|
||||
CudaIntegerRadixCiphertext, CudaRadixCiphertext, CudaSignedRadixCiphertext,
|
||||
CudaUnsignedRadixCiphertext,
|
||||
};
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::server_key::{
|
||||
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::core_crypto::commons::generators::DeterministicSeeder;
|
||||
@@ -351,6 +353,9 @@ impl CudaServerKey {
|
||||
}
|
||||
|
||||
let message_bits_count = self.message_modulus.0.ilog2();
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -364,8 +369,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -387,8 +392,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
d_bsk.grouping_factor,
|
||||
@@ -484,6 +489,10 @@ impl CudaServerKey {
|
||||
let mut result: CudaUnsignedRadixCiphertext =
|
||||
self.create_trivial_zero_radix(num_blocks_output as usize, streams);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -496,12 +505,12 @@ impl CudaServerKey {
|
||||
has_at_least_one_set.as_slice(),
|
||||
num_input_random_bits as u32,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.input_lwe_dimension,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -523,12 +532,12 @@ impl CudaServerKey {
|
||||
has_at_least_one_set.as_slice(),
|
||||
num_input_random_bits as u32,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.input_lwe_dimension,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
d_bsk.grouping_factor,
|
||||
@@ -553,6 +562,9 @@ impl CudaServerKey {
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
let message_bits = self.message_modulus.0.ilog2();
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_grouped_oprf_size_on_gpu(
|
||||
@@ -561,8 +573,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -579,8 +591,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
d_bsk.grouping_factor,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::LweBskGroupingFactor;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu, cuda_backend_get_rotate_left_size_on_gpu,
|
||||
cuda_backend_get_rotate_right_size_on_gpu, cuda_backend_unchecked_rotate_left_assign,
|
||||
@@ -19,6 +19,9 @@ impl CudaServerKey {
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let is_signed = T::IS_SIGNED;
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -28,19 +31,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
rotate.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -56,19 +55,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
rotate.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -106,6 +101,9 @@ impl CudaServerKey {
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let is_signed = T::IS_SIGNED;
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -115,19 +113,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
rotate.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -143,19 +137,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
rotate.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -437,6 +427,10 @@ impl CudaServerKey {
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
|
||||
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu(
|
||||
@@ -444,8 +438,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -461,8 +455,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -492,14 +486,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -515,14 +505,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -550,6 +536,10 @@ impl CudaServerKey {
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
|
||||
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu(
|
||||
@@ -557,8 +547,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -574,8 +564,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -605,14 +595,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -628,14 +614,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
|
||||
@@ -6,7 +6,9 @@ use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
use crate::integer::gpu::ciphertext::{
|
||||
CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext,
|
||||
};
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::server_key::{
|
||||
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
|
||||
};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu,
|
||||
cuda_backend_get_propagate_single_carry_assign_size_on_gpu,
|
||||
@@ -179,6 +181,10 @@ impl CudaServerKey {
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = if ct.block_carries_are_empty() {
|
||||
0
|
||||
} else {
|
||||
@@ -189,8 +195,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -206,8 +212,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -228,8 +234,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
@@ -247,8 +253,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::LweBskGroupingFactor;
|
||||
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
|
||||
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu, cuda_backend_get_scalar_bitop_size_on_gpu,
|
||||
cuda_backend_unchecked_scalar_bitop_assign, BitOpType, CudaServerKey, PBSType,
|
||||
@@ -29,6 +29,10 @@ impl CudaServerKey {
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let clear_blocks = unsafe { CudaVec::from_cpu_async(&h_clear_blocks, streams, 0) };
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -38,19 +42,15 @@ impl CudaServerKey {
|
||||
&clear_blocks,
|
||||
&h_clear_blocks,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
@@ -67,19 +67,15 @@ impl CudaServerKey {
|
||||
&clear_blocks,
|
||||
&h_clear_blocks,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
@@ -241,6 +237,9 @@ impl CudaServerKey {
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = if ct.block_carries_are_empty() {
|
||||
0
|
||||
@@ -252,8 +251,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -269,8 +268,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -291,14 +290,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
@@ -314,14 +309,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
|
||||
@@ -6,7 +6,9 @@ use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
use crate::integer::gpu::ciphertext::info::CudaRadixCiphertextInfo;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::server_key::{
|
||||
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
|
||||
};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_unchecked_are_all_comparisons_block_true,
|
||||
cuda_backend_unchecked_is_at_least_one_comparisons_block_true,
|
||||
@@ -167,6 +169,9 @@ impl CudaServerKey {
|
||||
|
||||
let mut result =
|
||||
CudaBooleanBlock::from_cuda_radix_ciphertext(CudaRadixCiphertext::new(block, ct_info));
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -178,19 +183,15 @@ impl CudaServerKey {
|
||||
&d_scalar_blocks,
|
||||
&scalar_blocks,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
scalar_blocks.len() as u32,
|
||||
@@ -209,19 +210,15 @@ impl CudaServerKey {
|
||||
&d_scalar_blocks,
|
||||
&scalar_blocks,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
scalar_blocks.len() as u32,
|
||||
@@ -320,6 +317,9 @@ impl CudaServerKey {
|
||||
unsafe { CudaVec::from_cpu_async(&scalar_blocks, streams, 0) };
|
||||
|
||||
let mut result = ct.duplicate(streams);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -331,19 +331,15 @@ impl CudaServerKey {
|
||||
&d_scalar_blocks,
|
||||
&scalar_blocks,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
scalar_blocks.len() as u32,
|
||||
@@ -362,19 +358,15 @@ impl CudaServerKey {
|
||||
&d_scalar_blocks,
|
||||
&scalar_blocks,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
scalar_blocks.len() as u32,
|
||||
@@ -400,6 +392,10 @@ impl CudaServerKey {
|
||||
{
|
||||
let ct_res: T = self.create_trivial_radix(0, 1, streams);
|
||||
let mut boolean_res = CudaBooleanBlock::from_cuda_radix_ciphertext(ct_res.into_inner());
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -408,19 +404,15 @@ impl CudaServerKey {
|
||||
boolean_res.as_mut().as_mut(),
|
||||
ct.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -434,19 +426,15 @@ impl CudaServerKey {
|
||||
boolean_res.as_mut().as_mut(),
|
||||
ct.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
@@ -469,6 +457,10 @@ impl CudaServerKey {
|
||||
{
|
||||
let ct_res: T = self.create_trivial_radix(0, 1, streams);
|
||||
let mut boolean_res = CudaBooleanBlock::from_cuda_radix_ciphertext(ct_res.into_inner());
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -477,19 +469,15 @@ impl CudaServerKey {
|
||||
boolean_res.as_mut().as_mut(),
|
||||
ct.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -503,19 +491,15 @@ impl CudaServerKey {
|
||||
boolean_res.as_mut().as_mut(),
|
||||
ct.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::integer::block_decomposition::DecomposableInto;
|
||||
use crate::integer::gpu::ciphertext::{
|
||||
CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext,
|
||||
};
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu,
|
||||
cuda_backend_get_scalar_div_rem_size_on_gpu, cuda_backend_get_scalar_div_size_on_gpu,
|
||||
@@ -85,6 +85,9 @@ impl CudaServerKey {
|
||||
);
|
||||
|
||||
let mut quotient = numerator.duplicate(streams);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -93,15 +96,15 @@ impl CudaServerKey {
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
&d_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -114,15 +117,15 @@ impl CudaServerKey {
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
@@ -220,6 +223,9 @@ impl CudaServerKey {
|
||||
numerator.as_ref().d_blocks.lwe_ciphertext_count().0,
|
||||
streams,
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -229,15 +235,15 @@ impl CudaServerKey {
|
||||
quotient.as_mut(),
|
||||
remainder.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
&d_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -251,15 +257,15 @@ impl CudaServerKey {
|
||||
quotient.as_mut(),
|
||||
remainder.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
@@ -420,6 +426,9 @@ impl CudaServerKey {
|
||||
);
|
||||
|
||||
let mut quotient: CudaSignedRadixCiphertext = numerator.duplicate(streams);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -428,15 +437,15 @@ impl CudaServerKey {
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
&d_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -449,15 +458,15 @@ impl CudaServerKey {
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
@@ -555,6 +564,9 @@ impl CudaServerKey {
|
||||
numerator.as_ref().d_blocks.lwe_ciphertext_count().0,
|
||||
streams,
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -564,15 +576,15 @@ impl CudaServerKey {
|
||||
quotient.as_mut(),
|
||||
remainder.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
&d_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -586,15 +598,15 @@ impl CudaServerKey {
|
||||
quotient.as_mut(),
|
||||
remainder.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
@@ -758,6 +770,10 @@ encrypted bits: {numerator_bits}, scalar bits: {}
|
||||
Scalar::BITS
|
||||
);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = if numerator.block_carries_are_empty() {
|
||||
0
|
||||
} else {
|
||||
@@ -768,8 +784,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -785,8 +801,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -808,8 +824,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -826,8 +842,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
@@ -864,6 +880,10 @@ encrypted bits: {numerator_bits}, scalar bits: {}
|
||||
Scalar::BITS
|
||||
);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_scalar_div_rem_size_on_gpu(
|
||||
streams,
|
||||
@@ -873,8 +893,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -891,8 +911,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
@@ -940,6 +960,9 @@ encrypted bits: {numerator_bits}, scalar bits: {}
|
||||
"The scalar divisor type must have a number of bits that is\
|
||||
>= to the number of bits encrypted in the ciphertext"
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_signed_scalar_div_size_on_gpu(
|
||||
@@ -952,8 +975,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
|
||||
d_bsk.input_lwe_dimension,
|
||||
d_bsk.decomp_base_log,
|
||||
d_bsk.decomp_level_count,
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
LweBskGroupingFactor(0),
|
||||
num_blocks,
|
||||
PBSType::Classical,
|
||||
@@ -970,8 +993,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
d_multibit_bsk.grouping_factor,
|
||||
num_blocks,
|
||||
PBSType::MultiBit,
|
||||
@@ -1001,6 +1024,9 @@ encrypted bits: {numerator_bits}, scalar bits: {}
|
||||
"The scalar divisor type must have a number of bits that is\
|
||||
>= to the number of bits encrypted in the ciphertext"
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -1012,8 +1038,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -1031,8 +1057,8 @@ encrypted bits: {numerator_bits}, scalar bits: {}
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
|
||||
@@ -2,7 +2,9 @@ use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::LweBskGroupingFactor;
|
||||
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
|
||||
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::server_key::{
|
||||
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
|
||||
};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu, cuda_backend_get_scalar_mul_size_on_gpu,
|
||||
cuda_backend_unchecked_scalar_mul, PBSType,
|
||||
@@ -107,6 +109,9 @@ impl CudaServerKey {
|
||||
if decomposed_scalar.is_empty() {
|
||||
return;
|
||||
}
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -117,18 +122,16 @@ impl CudaServerKey {
|
||||
decomposed_scalar.as_slice(),
|
||||
has_at_least_one_set.as_slice(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
d_bsk.decomp_base_log,
|
||||
d_bsk.decomp_level_count,
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
decomposed_scalar.len() as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -142,18 +145,16 @@ impl CudaServerKey {
|
||||
decomposed_scalar.as_slice(),
|
||||
has_at_least_one_set.as_slice(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
decomposed_scalar.len() as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
@@ -246,6 +247,9 @@ impl CudaServerKey {
|
||||
// than multiplying
|
||||
return self.get_scalar_left_shift_size_on_gpu(ct, streams);
|
||||
}
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = if ct.block_carries_are_empty() {
|
||||
0
|
||||
@@ -257,8 +261,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -274,8 +278,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -302,13 +306,11 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
d_bsk.decomp_base_log,
|
||||
d_bsk.decomp_level_count,
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -322,13 +324,11 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::{CastFrom, LweBskGroupingFactor};
|
||||
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu,
|
||||
cuda_backend_get_scalar_rotate_left_size_on_gpu,
|
||||
@@ -38,6 +38,10 @@ impl CudaServerKey {
|
||||
u32: CastFrom<Scalar>,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -46,19 +50,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
u32::cast_from(n),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -73,19 +73,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
u32::cast_from(n),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -125,6 +121,10 @@ impl CudaServerKey {
|
||||
u32: CastFrom<Scalar>,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -133,19 +133,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
u32::cast_from(n),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -160,19 +156,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
u32::cast_from(n),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -237,6 +229,9 @@ impl CudaServerKey {
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = if ct.block_carries_are_empty() {
|
||||
0
|
||||
@@ -248,8 +243,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -265,8 +260,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -286,14 +281,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -309,14 +300,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -334,6 +321,9 @@ impl CudaServerKey {
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = if ct.block_carries_are_empty() {
|
||||
0
|
||||
@@ -345,8 +335,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -362,8 +352,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -382,14 +372,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -403,14 +389,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::{CastFrom, LweBskGroupingFactor};
|
||||
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu,
|
||||
cuda_backend_get_scalar_arithmetic_right_shift_size_on_gpu,
|
||||
@@ -76,6 +76,9 @@ impl CudaServerKey {
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -85,19 +88,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -112,19 +111,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -200,6 +195,9 @@ impl CudaServerKey {
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
if T::IS_SIGNED {
|
||||
@@ -210,19 +208,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -236,19 +230,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
@@ -265,19 +255,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -292,19 +278,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -451,10 +433,16 @@ impl CudaServerKey {
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
let full_prop_mem = if ct.block_carries_are_empty() {
|
||||
0
|
||||
} else {
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key
|
||||
else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu(
|
||||
@@ -462,8 +450,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -479,8 +467,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -492,6 +480,7 @@ impl CudaServerKey {
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let scalar_shift_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_scalar_left_shift_size_on_gpu(
|
||||
streams,
|
||||
@@ -499,14 +488,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -521,14 +506,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -546,6 +527,9 @@ impl CudaServerKey {
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = if ct.block_carries_are_empty() {
|
||||
0
|
||||
@@ -557,8 +541,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -574,8 +558,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -596,14 +580,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -619,14 +599,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -645,14 +621,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -668,14 +640,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::LweBskGroupingFactor;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaDynamicKeyswitchingKey};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu, cuda_backend_get_left_shift_size_on_gpu,
|
||||
cuda_backend_get_right_shift_size_on_gpu, cuda_backend_unchecked_left_shift_assign,
|
||||
@@ -19,6 +19,9 @@ impl CudaServerKey {
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let is_signed = T::IS_SIGNED;
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -28,19 +31,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
shift.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -56,19 +55,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
shift.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -106,6 +101,9 @@ impl CudaServerKey {
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let is_signed = T::IS_SIGNED;
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -115,19 +113,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
shift.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -143,19 +137,15 @@ impl CudaServerKey {
|
||||
ct.as_mut(),
|
||||
shift.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -435,6 +425,10 @@ impl CudaServerKey {
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
|
||||
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu(
|
||||
@@ -442,8 +436,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -459,8 +453,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -482,6 +476,9 @@ impl CudaServerKey {
|
||||
};
|
||||
|
||||
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let shift_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_left_shift_size_on_gpu(
|
||||
@@ -490,14 +487,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -513,14 +506,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -548,6 +537,10 @@ impl CudaServerKey {
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
|
||||
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let full_prop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_get_full_propagate_assign_size_on_gpu(
|
||||
@@ -555,8 +548,8 @@ impl CudaServerKey {
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -572,8 +565,8 @@ impl CudaServerKey {
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
@@ -595,6 +588,9 @@ impl CudaServerKey {
|
||||
};
|
||||
|
||||
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
let shift_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_right_shift_size_on_gpu(
|
||||
@@ -603,14 +599,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
@@ -626,14 +618,10 @@ impl CudaServerKey {
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::integer::gpu::ciphertext::{
|
||||
CudaIntegerRadixCiphertext, CudaRadixCiphertext, CudaSignedRadixCiphertext,
|
||||
CudaUnsignedRadixCiphertext,
|
||||
};
|
||||
use crate::integer::gpu::server_key::CudaServerKey;
|
||||
use crate::integer::gpu::server_key::{CudaDynamicKeyswitchingKey, CudaServerKey};
|
||||
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::{
|
||||
@@ -296,6 +296,9 @@ impl CudaServerKey {
|
||||
let aux_block: CudaUnsignedRadixCiphertext = self.create_trivial_zero_radix(1, stream);
|
||||
let in_carry_dvec =
|
||||
INPUT_BORROW.map_or_else(|| aux_block.as_ref(), |block| block.as_ref().as_ref());
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -307,12 +310,12 @@ impl CudaServerKey {
|
||||
overflow_block.as_mut(),
|
||||
in_carry_dvec,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
ciphertext.info.blocks.first().unwrap().message_modulus,
|
||||
@@ -332,12 +335,12 @@ impl CudaServerKey {
|
||||
overflow_block.as_mut(),
|
||||
in_carry_dvec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
ciphertext.info.blocks.first().unwrap().message_modulus,
|
||||
@@ -374,6 +377,9 @@ impl CudaServerKey {
|
||||
let aux_block: T = self.create_trivial_zero_radix(1, streams);
|
||||
let in_carry: &CudaRadixCiphertext =
|
||||
input_carry.map_or_else(|| aux_block.as_ref(), |block| block.0.as_ref());
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -385,12 +391,12 @@ impl CudaServerKey {
|
||||
carry_out.as_mut(),
|
||||
in_carry,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
@@ -411,12 +417,12 @@ impl CudaServerKey {
|
||||
carry_out.as_mut(),
|
||||
in_carry,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
|
||||
@@ -18,7 +18,9 @@ use crate::core_crypto::prelude::*;
|
||||
use crate::integer::gpu::ciphertext::info::CudaBlockInfo;
|
||||
use crate::integer::gpu::ciphertext::CudaRadixCiphertext;
|
||||
use crate::integer::gpu::server_key::radix::{CudaNoiseSquashingKey, CudaRadixCiphertextInfo};
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::server_key::{
|
||||
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
|
||||
};
|
||||
use crate::integer::gpu::{
|
||||
cuda_centered_modulus_switch_64, unchecked_small_scalar_mul_integer_async, CudaStreams,
|
||||
};
|
||||
@@ -417,10 +419,14 @@ impl AllocateLweKeyswitchResult for CudaServerKey {
|
||||
&self,
|
||||
side_resources: &mut Self::SideResources,
|
||||
) -> Self::Output {
|
||||
let output_lwe_dimension = self
|
||||
.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension();
|
||||
let output_lwe_dimension = match &self.key_switching_key {
|
||||
CudaDynamicKeyswitchingKey::Standard(std_key) => {
|
||||
std_key.output_key_lwe_size().to_lwe_dimension()
|
||||
}
|
||||
CudaDynamicKeyswitchingKey::KeySwitch32(ks32_key) => {
|
||||
ks32_key.output_key_lwe_size().to_lwe_dimension()
|
||||
}
|
||||
};
|
||||
let lwe_ciphertext_count = LweCiphertextCount(1);
|
||||
let ciphertext_modulus = self.ciphertext_modulus;
|
||||
|
||||
@@ -444,12 +450,39 @@ impl LweKeyswitch<CudaDynLwe, CudaDynLwe> for CudaServerKey {
|
||||
side_resources: &mut Self::SideResources,
|
||||
) {
|
||||
match (input, output) {
|
||||
(CudaDynLwe::U64(input_cuda_lwe), CudaDynLwe::U64(output_cuda_lwe)) => {
|
||||
(CudaDynLwe::U64(input_cuda_lwe), CudaDynLwe::U32(output_cuda_lwe)) => {
|
||||
let CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) =
|
||||
&self.key_switching_key
|
||||
else {
|
||||
panic!("Expecting 32b KSK in Cuda noise simulation tests when LWE is 32b");
|
||||
};
|
||||
|
||||
let input_indexes = CudaVec::new(1, &side_resources.streams, 0);
|
||||
let output_indexes = CudaVec::new(1, &side_resources.streams, 0);
|
||||
|
||||
cuda_keyswitch_lwe_ciphertext(
|
||||
&self.key_switching_key,
|
||||
computing_ks_key,
|
||||
input_cuda_lwe,
|
||||
output_cuda_lwe,
|
||||
&input_indexes,
|
||||
&output_indexes,
|
||||
false,
|
||||
&side_resources.streams,
|
||||
false,
|
||||
);
|
||||
}
|
||||
(CudaDynLwe::U64(input_cuda_lwe), CudaDynLwe::U64(output_cuda_lwe)) => {
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) =
|
||||
&self.key_switching_key
|
||||
else {
|
||||
panic!("Expecting 64b KSK in Cuda noise simulation tests when LWE is 64b");
|
||||
};
|
||||
|
||||
let input_indexes = CudaVec::new(1, &side_resources.streams, 0);
|
||||
let output_indexes = CudaVec::new(1, &side_resources.streams, 0);
|
||||
|
||||
cuda_keyswitch_lwe_ciphertext(
|
||||
computing_ks_key,
|
||||
input_cuda_lwe,
|
||||
output_cuda_lwe,
|
||||
&input_indexes,
|
||||
|
||||
@@ -2,7 +2,9 @@ use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::LweBskGroupingFactor;
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::server_key::{
|
||||
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
|
||||
};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_unchecked_all_eq_slices, cuda_backend_unchecked_contains_sub_slice, PBSType,
|
||||
};
|
||||
@@ -56,6 +58,10 @@ impl CudaServerKey {
|
||||
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
|
||||
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -65,19 +71,15 @@ impl CudaServerKey {
|
||||
lhs,
|
||||
rhs,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -92,19 +94,15 @@ impl CudaServerKey {
|
||||
lhs,
|
||||
rhs,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
@@ -269,6 +267,10 @@ impl CudaServerKey {
|
||||
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
|
||||
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -278,19 +280,15 @@ impl CudaServerKey {
|
||||
lhs,
|
||||
rhs,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -305,19 +303,15 @@ impl CudaServerKey {
|
||||
lhs,
|
||||
rhs,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
|
||||
@@ -3,7 +3,9 @@ use crate::core_crypto::prelude::{LweBskGroupingFactor, UnsignedInteger};
|
||||
use crate::integer::block_decomposition::DecomposableInto;
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::server_key::{
|
||||
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
|
||||
};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_get_unchecked_match_value_or_size_on_gpu,
|
||||
cuda_backend_get_unchecked_match_value_size_on_gpu, cuda_backend_unchecked_contains,
|
||||
@@ -52,6 +54,10 @@ impl CudaServerKey {
|
||||
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams),
|
||||
);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -64,17 +70,13 @@ impl CudaServerKey {
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -92,17 +94,13 @@ impl CudaServerKey {
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
@@ -130,6 +128,10 @@ impl CudaServerKey {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_get_unchecked_match_value_size_on_gpu(
|
||||
@@ -138,14 +140,10 @@ impl CudaServerKey {
|
||||
matches,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -162,14 +160,10 @@ impl CudaServerKey {
|
||||
matches,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
@@ -296,6 +290,10 @@ impl CudaServerKey {
|
||||
let mut result: CudaUnsignedRadixCiphertext =
|
||||
self.create_trivial_zero_radix(final_num_blocks, streams);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -308,17 +306,13 @@ impl CudaServerKey {
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -336,17 +330,13 @@ impl CudaServerKey {
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
@@ -375,6 +365,10 @@ impl CudaServerKey {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_get_unchecked_match_value_or_size_on_gpu(
|
||||
@@ -384,14 +378,10 @@ impl CudaServerKey {
|
||||
or_value,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
@@ -409,14 +399,10 @@ impl CudaServerKey {
|
||||
or_value,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
@@ -519,6 +505,9 @@ impl CudaServerKey {
|
||||
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams)
|
||||
.into_inner(),
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -529,19 +518,15 @@ impl CudaServerKey {
|
||||
cts,
|
||||
value.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -556,19 +541,15 @@ impl CudaServerKey {
|
||||
cts,
|
||||
value.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
@@ -675,6 +656,9 @@ impl CudaServerKey {
|
||||
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams)
|
||||
.into_inner(),
|
||||
);
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -685,19 +669,15 @@ impl CudaServerKey {
|
||||
cts,
|
||||
clear,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -712,19 +692,15 @@ impl CudaServerKey {
|
||||
cts,
|
||||
clear,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
@@ -822,6 +798,9 @@ impl CudaServerKey {
|
||||
|
||||
let ct_res: CudaUnsignedRadixCiphertext = self.create_trivial_radix(0, 1, streams);
|
||||
let mut boolean_res = CudaBooleanBlock::from_cuda_radix_ciphertext(ct_res.into_inner());
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -832,19 +811,15 @@ impl CudaServerKey {
|
||||
ct.as_ref(),
|
||||
clears,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -859,19 +834,15 @@ impl CudaServerKey {
|
||||
ct.as_ref(),
|
||||
clears,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
@@ -982,6 +953,9 @@ impl CudaServerKey {
|
||||
let trivial_bool =
|
||||
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
|
||||
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -993,19 +967,15 @@ impl CudaServerKey {
|
||||
ct.as_ref(),
|
||||
clears,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -1021,19 +991,15 @@ impl CudaServerKey {
|
||||
ct.as_ref(),
|
||||
clears,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
@@ -1154,6 +1120,9 @@ impl CudaServerKey {
|
||||
let trivial_bool =
|
||||
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
|
||||
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -1165,19 +1134,15 @@ impl CudaServerKey {
|
||||
ct.as_ref(),
|
||||
clears,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -1193,19 +1158,15 @@ impl CudaServerKey {
|
||||
ct.as_ref(),
|
||||
clears,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
@@ -1313,6 +1274,9 @@ impl CudaServerKey {
|
||||
|
||||
let trivial_bool: CudaUnsignedRadixCiphertext = self.create_trivial_zero_radix(1, streams);
|
||||
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
@@ -1324,19 +1288,15 @@ impl CudaServerKey {
|
||||
cts,
|
||||
value.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -1352,19 +1312,15 @@ impl CudaServerKey {
|
||||
cts,
|
||||
value.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
@@ -1501,6 +1457,10 @@ impl CudaServerKey {
|
||||
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
|
||||
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -1511,19 +1471,15 @@ impl CudaServerKey {
|
||||
cts,
|
||||
clear,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -1539,19 +1495,15 @@ impl CudaServerKey {
|
||||
cts,
|
||||
clear,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
@@ -1677,6 +1629,10 @@ impl CudaServerKey {
|
||||
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
|
||||
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -1687,19 +1643,15 @@ impl CudaServerKey {
|
||||
cts,
|
||||
clear,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -1715,19 +1667,15 @@ impl CudaServerKey {
|
||||
cts,
|
||||
clear,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
@@ -1851,6 +1799,10 @@ impl CudaServerKey {
|
||||
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
|
||||
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
@@ -1861,19 +1813,15 @@ impl CudaServerKey {
|
||||
cts,
|
||||
value.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
@@ -1889,19 +1837,15 @@ impl CudaServerKey {
|
||||
cts,
|
||||
value.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
computing_ks_key.input_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.output_key_lwe_size().to_lwe_dimension(),
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
|
||||
Reference in New Issue
Block a user