mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-07 22:04:10 -05:00
feat(gpu): support keyswitch 64/32
This commit is contained in:
committed by
Andrei Stoian
parent
14d49f0891
commit
78d1ce18c1
@@ -339,7 +339,10 @@ mod cuda {
|
||||
|
||||
use tfhe::core_crypto::prelude::*;
|
||||
|
||||
fn cuda_keyswitch<Scalar: UnsignedTorus + CastInto<usize> + CastFrom<u64> + Serialize>(
|
||||
fn cuda_keyswitch_classical_and_gemm<
|
||||
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<u64> + Serialize,
|
||||
KeyswitchScalar: UnsignedTorus + CastFrom<Scalar>,
|
||||
>(
|
||||
criterion: &mut Criterion,
|
||||
parameters: &[(String, CryptoParametersRecord<Scalar>)],
|
||||
) {
|
||||
@@ -361,27 +364,57 @@ mod cuda {
|
||||
let ks_decomp_base_log = params.ks_base_log.unwrap();
|
||||
let ks_decomp_level_count = params.ks_level.unwrap();
|
||||
|
||||
let lwe_sk = allocate_and_generate_new_binary_lwe_secret_key(
|
||||
lwe_dimension,
|
||||
&mut secret_generator,
|
||||
);
|
||||
let lwe_noise_distribution_ksk = match params.lwe_noise_distribution.unwrap() {
|
||||
DynamicDistribution::Gaussian(gaussian_lwe_noise_distribution) => {
|
||||
DynamicDistribution::<KeyswitchScalar>::new_gaussian(
|
||||
gaussian_lwe_noise_distribution.standard_dev(),
|
||||
)
|
||||
}
|
||||
DynamicDistribution::TUniform(uniform_lwe_noise_distribution) => {
|
||||
DynamicDistribution::<KeyswitchScalar>::new_t_uniform(
|
||||
match KeyswitchScalar::BITS {
|
||||
32 => uniform_lwe_noise_distribution.bound_log2() - 32,
|
||||
64 => uniform_lwe_noise_distribution.bound_log2(),
|
||||
_ => panic!("Unsupported Keyswitch scalar input dtype"),
|
||||
},
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
&mut secret_generator,
|
||||
);
|
||||
let lwe_sk: LweSecretKeyOwned<KeyswitchScalar> =
|
||||
allocate_and_generate_new_binary_lwe_secret_key(
|
||||
lwe_dimension,
|
||||
&mut secret_generator,
|
||||
);
|
||||
|
||||
let glwe_sk: GlweSecretKeyOwned<KeyswitchScalar> =
|
||||
allocate_and_generate_new_binary_glwe_secret_key(
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
&mut secret_generator,
|
||||
);
|
||||
let big_lwe_sk = glwe_sk.into_lwe_secret_key();
|
||||
|
||||
let ksk_big_to_small = allocate_and_generate_new_lwe_keyswitch_key(
|
||||
&big_lwe_sk,
|
||||
&lwe_sk,
|
||||
ks_decomp_base_log,
|
||||
ks_decomp_level_count,
|
||||
params.lwe_noise_distribution.unwrap(),
|
||||
lwe_noise_distribution_ksk,
|
||||
CiphertextModulus::new_native(),
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
let glwe_sk_64: GlweSecretKeyOwned<Scalar> =
|
||||
allocate_and_generate_new_binary_glwe_secret_key(
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
&mut secret_generator,
|
||||
);
|
||||
|
||||
let big_lwe_sk_64 = glwe_sk_64.into_lwe_secret_key();
|
||||
let ciphertext_modulus_out = CiphertextModulus::<KeyswitchScalar>::new_native();
|
||||
|
||||
let cpu_keys: CpuKeys<_> = CpuKeysBuilder::new()
|
||||
.keyswitch_key(ksk_big_to_small)
|
||||
.build();
|
||||
@@ -394,7 +427,7 @@ mod cuda {
|
||||
let gpu_keys = CudaLocalKeys::from_cpu_keys(&cpu_keys, None, &streams);
|
||||
|
||||
let ct = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
&big_lwe_sk,
|
||||
&big_lwe_sk_64,
|
||||
Plaintext(Scalar::ONE),
|
||||
params.lwe_noise_distribution.unwrap(),
|
||||
CiphertextModulus::new_native(),
|
||||
@@ -403,7 +436,7 @@ mod cuda {
|
||||
let mut ct_gpu = CudaLweCiphertextList::from_lwe_ciphertext(&ct, &streams);
|
||||
|
||||
let output_ct = LweCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
KeyswitchScalar::ZERO,
|
||||
lwe_sk.lwe_dimension().to_lwe_size(),
|
||||
CiphertextModulus::new_native(),
|
||||
);
|
||||
@@ -413,7 +446,10 @@ mod cuda {
|
||||
let h_indexes = [Scalar::ZERO];
|
||||
let cuda_indexes = CudaIndexes::new(&h_indexes, &streams, 0);
|
||||
|
||||
bench_id = format!("{bench_name}::{name}");
|
||||
bench_id = format!(
|
||||
"{bench_name}::latency::{:?}b::{name}",
|
||||
KeyswitchScalar::BITS
|
||||
);
|
||||
{
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
@@ -457,7 +493,8 @@ mod cuda {
|
||||
};
|
||||
let gemm_str = if uses_gemm_ks { "gemm" } else { "classical" };
|
||||
bench_id = format!(
|
||||
"{bench_name}::throughput::{gemm_str}::{indices_str}_indices::{name}",
|
||||
"{bench_name}::throughput::{:?}b::{gemm_str}::{indices_str}_indices::{name}",
|
||||
KeyswitchScalar::BITS
|
||||
);
|
||||
|
||||
let blocks: usize = 256;
|
||||
@@ -483,7 +520,7 @@ mod cuda {
|
||||
params.ciphertext_modulus.unwrap(),
|
||||
);
|
||||
encrypt_lwe_ciphertext_list(
|
||||
&big_lwe_sk,
|
||||
&big_lwe_sk_64,
|
||||
&mut input_ct_list,
|
||||
&plaintext_list,
|
||||
params.lwe_noise_distribution.unwrap(),
|
||||
@@ -504,10 +541,10 @@ mod cuda {
|
||||
let output_cts = (0..gpu_count)
|
||||
.map(|i| {
|
||||
let output_ct_list = LweCiphertextList::new(
|
||||
Scalar::ZERO,
|
||||
KeyswitchScalar::ZERO,
|
||||
lwe_sk.lwe_dimension().to_lwe_size(),
|
||||
LweCiphertextCount(elements_per_stream),
|
||||
params.ciphertext_modulus.unwrap(),
|
||||
ciphertext_modulus_out,
|
||||
);
|
||||
CudaLweCiphertextList::from_lwe_ciphertext_list(
|
||||
&output_ct_list,
|
||||
@@ -584,7 +621,7 @@ mod cuda {
|
||||
}
|
||||
|
||||
fn cuda_packing_keyswitch<
|
||||
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<u64> + Serialize,
|
||||
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<u64> + Serialize + CastInto<u32>,
|
||||
>(
|
||||
criterion: &mut Criterion,
|
||||
parameters: &[(String, CryptoParametersRecord<Scalar>)],
|
||||
@@ -791,9 +828,9 @@ mod cuda {
|
||||
.zip(local_streams.par_iter())
|
||||
.for_each(
|
||||
|(
|
||||
((i, input_lwe_list), output_glwe_list),
|
||||
local_stream,
|
||||
)| {
|
||||
((i, input_lwe_list), output_glwe_list),
|
||||
local_stream,
|
||||
)| {
|
||||
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64(
|
||||
gpu_keys_vec[i].pksk.as_ref().unwrap(),
|
||||
input_lwe_list,
|
||||
@@ -826,7 +863,8 @@ mod cuda {
|
||||
let mut criterion: Criterion<_> = (Criterion::default().sample_size(15))
|
||||
.measurement_time(std::time::Duration::from_secs(60))
|
||||
.configure_from_args();
|
||||
cuda_keyswitch(&mut criterion, &benchmark_parameters());
|
||||
cuda_keyswitch_classical_and_gemm::<u64, u32>(&mut criterion, &benchmark_parameters());
|
||||
cuda_keyswitch_classical_and_gemm::<u64, u64>(&mut criterion, &benchmark_parameters());
|
||||
cuda_packing_keyswitch(&mut criterion, &benchmark_parameters());
|
||||
}
|
||||
|
||||
@@ -834,7 +872,8 @@ mod cuda {
|
||||
let mut criterion: Criterion<_> = (Criterion::default().sample_size(15))
|
||||
.measurement_time(std::time::Duration::from_secs(60))
|
||||
.configure_from_args();
|
||||
cuda_keyswitch(&mut criterion, &benchmark_parameters());
|
||||
cuda_keyswitch_classical_and_gemm::<u64, u32>(&mut criterion, &benchmark_parameters());
|
||||
cuda_keyswitch_classical_and_gemm::<u64, u64>(&mut criterion, &benchmark_parameters());
|
||||
}
|
||||
|
||||
pub fn cuda_multi_bit_ks_group() {
|
||||
@@ -844,7 +883,8 @@ mod cuda {
|
||||
.into_iter()
|
||||
.map(|(string, params, _)| (string, params))
|
||||
.collect_vec();
|
||||
cuda_keyswitch(&mut criterion, &multi_bit_parameters);
|
||||
cuda_keyswitch_classical_and_gemm::<u64, u32>(&mut criterion, &multi_bit_parameters);
|
||||
cuda_keyswitch_classical_and_gemm::<u64, u64>(&mut criterion, &multi_bit_parameters);
|
||||
cuda_packing_keyswitch(&mut criterion, &multi_bit_parameters);
|
||||
}
|
||||
|
||||
@@ -855,7 +895,8 @@ mod cuda {
|
||||
.into_iter()
|
||||
.map(|(string, params, _)| (string, params))
|
||||
.collect_vec();
|
||||
cuda_keyswitch(&mut criterion, &multi_bit_parameters);
|
||||
cuda_keyswitch_classical_and_gemm::<u64, u32>(&mut criterion, &multi_bit_parameters);
|
||||
cuda_keyswitch_classical_and_gemm::<u64, u64>(&mut criterion, &multi_bit_parameters);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user