mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
chore(gpu): add a benchmark for keyswitch on GPU
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# Run PBS benchmarks on an AWS instance and return parsed results to Slab CI bot.
|
||||
name: PBS benchmarks
|
||||
# Run core crypto benchmarks on an AWS instance and return parsed results to Slab CI bot.
|
||||
name: Core crypto benchmarks
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
@@ -35,8 +35,8 @@ env:
|
||||
RUST_BACKTRACE: "full"
|
||||
|
||||
jobs:
|
||||
run-pbs-benchmarks:
|
||||
name: Execute PBS benchmarks in EC2
|
||||
run-core-crypto-benchmarks:
|
||||
name: Execute core crypto benchmarks in EC2
|
||||
runs-on: ${{ github.event.inputs.runner_name }}
|
||||
if: ${{ !cancelled() }}
|
||||
steps:
|
||||
@@ -69,6 +69,7 @@ jobs:
|
||||
- name: Run benchmarks with AVX512
|
||||
run: |
|
||||
make AVX512_SUPPORT=ON bench_pbs
|
||||
make AVX512_SUPPORT=ON bench_ks
|
||||
|
||||
- name: Parse results
|
||||
run: |
|
||||
@@ -1,5 +1,5 @@
|
||||
# Run PBS benchmarks on an AWS instance with CUDA and return parsed results to Slab CI bot.
|
||||
name: PBS GPU benchmarks
|
||||
# Run core crypto benchmarks on an AWS instance with CUDA and return parsed results to Slab CI bot.
|
||||
name: Core crypto GPU benchmarks
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
@@ -34,8 +34,8 @@ env:
|
||||
ACTION_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
|
||||
|
||||
jobs:
|
||||
run-pbs-benchmarks:
|
||||
name: Execute PBS benchmarks in EC2
|
||||
run-core-crypto-benchmarks:
|
||||
name: Execute GPU core crypto benchmarks in EC2
|
||||
runs-on: ${{ github.event.inputs.runner_name }}
|
||||
if: ${{ !cancelled() }}
|
||||
steps:
|
||||
@@ -85,6 +85,7 @@ jobs:
|
||||
- name: Run benchmarks with AVX512
|
||||
run: |
|
||||
make AVX512_SUPPORT=ON bench_pbs_gpu
|
||||
make AVX512_SUPPORT=ON bench_ks_gpu
|
||||
|
||||
- name: Parse results
|
||||
run: |
|
||||
12
Makefile
12
Makefile
@@ -726,6 +726,18 @@ bench_pbs_gpu: install_rs_check_toolchain
|
||||
--bench pbs-bench \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,gpu,internal-keycache,$(AVX512_FEATURE) -p $(TFHE_SPEC)
|
||||
|
||||
.PHONY: bench_ks # Run benchmarks for keyswitch
|
||||
bench_ks: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench ks-bench \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,internal-keycache,$(AVX512_FEATURE) -p $(TFHE_SPEC)
|
||||
|
||||
.PHONY: bench_ks_gpu # Run benchmarks for keyswitch on GPU backend
|
||||
bench_ks_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench ks-bench \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,gpu,internal-keycache,$(AVX512_FEATURE) -p $(TFHE_SPEC)
|
||||
|
||||
.PHONY: bench_web_js_api_parallel # Run benchmarks for the web wasm api
|
||||
bench_web_js_api_parallel: build_web_js_api_parallel
|
||||
$(MAKE) -C tfhe/web_wasm_parallel_tests bench
|
||||
|
||||
@@ -1,49 +1,72 @@
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
#[path = "../utilities.rs"]
|
||||
mod utilities;
|
||||
use crate::utilities::{write_to_json, CryptoParametersRecord, OperatorType};
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
use serde::Serialize;
|
||||
use tfhe::boolean::prelude::*;
|
||||
use tfhe::core_crypto::prelude::*;
|
||||
use tfhe::keycache::NamedParam;
|
||||
use tfhe::shortint::prelude::*;
|
||||
use tfhe::shortint::PBSParameters;
|
||||
|
||||
fn criterion_bench(criterion: &mut Criterion) {
|
||||
type Scalar = u64;
|
||||
const SHORTINT_BENCH_PARAMS: [ClassicPBSParameters; 4] = [
|
||||
PARAM_MESSAGE_1_CARRY_1_KS_PBS,
|
||||
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
|
||||
PARAM_MESSAGE_3_CARRY_3_KS_PBS,
|
||||
PARAM_MESSAGE_4_CARRY_4_KS_PBS,
|
||||
];
|
||||
|
||||
let mut bench_group = criterion.benchmark_group("KS");
|
||||
bench_group
|
||||
.sample_size(15)
|
||||
.measurement_time(std::time::Duration::from_secs(60));
|
||||
const BOOLEAN_BENCH_PARAMS: [(&str, BooleanParameters); 2] = [
|
||||
("BOOLEAN_DEFAULT_PARAMS", DEFAULT_PARAMETERS),
|
||||
(
|
||||
"BOOLEAN_TFHE_LIB_PARAMS",
|
||||
PARAMETERS_ERROR_PROB_2_POW_MINUS_165,
|
||||
),
|
||||
];
|
||||
|
||||
for params in [
|
||||
PARAM_MESSAGE_1_CARRY_1_KS_PBS,
|
||||
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
|
||||
PARAM_MESSAGE_3_CARRY_3_KS_PBS,
|
||||
PARAM_MESSAGE_4_CARRY_4_KS_PBS,
|
||||
]
|
||||
.into_iter()
|
||||
{
|
||||
let lwe_dimension = params.lwe_dimension;
|
||||
let lwe_modular_std_dev = params.lwe_modular_std_dev;
|
||||
let ciphertext_modulus = params.ciphertext_modulus;
|
||||
let encoding_with_padding = if ciphertext_modulus.is_native_modulus() {
|
||||
Scalar::ONE << (Scalar::BITS - 1)
|
||||
} else {
|
||||
Scalar::cast_from(ciphertext_modulus.get_custom_modulus() / 2)
|
||||
};
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let ks_decomp_base_log = params.ks_base_log;
|
||||
let ks_decomp_level_count = params.ks_level;
|
||||
let msg_modulus: Scalar = params.message_modulus.0.cast_into();
|
||||
let total_modulus: Scalar = (params.message_modulus.0 * params.carry_modulus.0).cast_into();
|
||||
fn benchmark_parameters<Scalar: UnsignedInteger>() -> Vec<(String, CryptoParametersRecord<Scalar>)>
|
||||
{
|
||||
if Scalar::BITS == 64 {
|
||||
SHORTINT_BENCH_PARAMS
|
||||
.iter()
|
||||
.map(|params| {
|
||||
(
|
||||
params.name(),
|
||||
<ClassicPBSParameters as Into<PBSParameters>>::into(*params)
|
||||
.to_owned()
|
||||
.into(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
} else if Scalar::BITS == 32 {
|
||||
BOOLEAN_BENCH_PARAMS
|
||||
.iter()
|
||||
.map(|(name, params)| (name.to_string(), params.to_owned().into()))
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
let msg = msg_modulus - 1;
|
||||
let delta: Scalar = encoding_with_padding / total_modulus;
|
||||
fn keyswitch<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(criterion: &mut Criterion) {
|
||||
let bench_name = "core_crypto::keyswitch";
|
||||
let mut bench_group = criterion.benchmark_group(bench_name);
|
||||
|
||||
// Create the PRNG
|
||||
let mut seeder = new_seeder();
|
||||
let seeder = seeder.as_mut();
|
||||
let mut encryption_generator =
|
||||
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
|
||||
let mut secret_generator =
|
||||
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
|
||||
// Create the PRNG
|
||||
let mut seeder = new_seeder();
|
||||
let seeder = seeder.as_mut();
|
||||
let mut encryption_generator =
|
||||
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
|
||||
let mut secret_generator =
|
||||
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
|
||||
|
||||
for (name, params) in benchmark_parameters::<Scalar>().iter() {
|
||||
let lwe_dimension = params.lwe_dimension.unwrap();
|
||||
let lwe_modular_std_dev = params.lwe_modular_std_dev.unwrap();
|
||||
let glwe_dimension = params.glwe_dimension.unwrap();
|
||||
let polynomial_size = params.polynomial_size.unwrap();
|
||||
let ks_decomp_base_log = params.ks_base_log.unwrap();
|
||||
let ks_decomp_level_count = params.ks_level.unwrap();
|
||||
|
||||
let lwe_sk =
|
||||
allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator);
|
||||
@@ -60,32 +83,176 @@ fn criterion_bench(criterion: &mut Criterion) {
|
||||
ks_decomp_base_log,
|
||||
ks_decomp_level_count,
|
||||
lwe_modular_std_dev,
|
||||
ciphertext_modulus,
|
||||
tfhe::core_crypto::prelude::CiphertextModulus::new_native(),
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
let plaintext = Plaintext(msg * delta);
|
||||
let ct = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
&big_lwe_sk,
|
||||
plaintext,
|
||||
Plaintext(Scalar::ONE),
|
||||
lwe_modular_std_dev,
|
||||
ciphertext_modulus,
|
||||
tfhe::core_crypto::prelude::CiphertextModulus::new_native(),
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
let mut output_ct = LweCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
lwe_sk.lwe_dimension().to_lwe_size(),
|
||||
ciphertext_modulus,
|
||||
tfhe::core_crypto::prelude::CiphertextModulus::new_native(),
|
||||
);
|
||||
|
||||
bench_group.bench_function(¶ms.name(), |bencher| {
|
||||
bencher.iter(|| {
|
||||
keyswitch_lwe_ciphertext(&ksk_big_to_small, &ct, &mut output_ct);
|
||||
})
|
||||
});
|
||||
let id = format!("{bench_name}_{name}");
|
||||
{
|
||||
bench_group.bench_function(&id, |b| {
|
||||
b.iter(|| {
|
||||
keyswitch_lwe_ciphertext(&ksk_big_to_small, &ct, &mut output_ct);
|
||||
black_box(&mut output_ct);
|
||||
})
|
||||
});
|
||||
}
|
||||
let bit_size = (params.message_modulus.unwrap_or(2) as u32).ilog2();
|
||||
write_to_json(
|
||||
&id,
|
||||
*params,
|
||||
name,
|
||||
"ks",
|
||||
&OperatorType::Atomic,
|
||||
bit_size,
|
||||
vec![bit_size],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_bench);
|
||||
criterion_main!(benches);
|
||||
#[cfg(feature = "gpu")]
|
||||
mod cuda {
|
||||
use crate::benchmark_parameters;
|
||||
use crate::utilities::{write_to_json, OperatorType};
|
||||
use criterion::{black_box, criterion_group, Criterion};
|
||||
use serde::Serialize;
|
||||
use tfhe::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
|
||||
use tfhe::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
|
||||
use tfhe::core_crypto::gpu::{cuda_keyswitch_lwe_ciphertext, CudaDevice, CudaStream};
|
||||
use tfhe::core_crypto::prelude::*;
|
||||
use tfhe::keycache::NamedParam;
|
||||
|
||||
fn cuda_keyswitch<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(
|
||||
criterion: &mut Criterion,
|
||||
) {
|
||||
let bench_name = "core_crypto::cuda::keyswitch";
|
||||
let mut bench_group = criterion.benchmark_group(bench_name);
|
||||
|
||||
// Create the PRNG
|
||||
let mut seeder = new_seeder();
|
||||
let seeder = seeder.as_mut();
|
||||
let mut encryption_generator =
|
||||
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
|
||||
let mut secret_generator =
|
||||
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
|
||||
|
||||
let gpu_index = 0;
|
||||
let device = CudaDevice::new(gpu_index);
|
||||
let stream = CudaStream::new_unchecked(device);
|
||||
|
||||
for (name, params) in benchmark_parameters::<Scalar>().iter() {
|
||||
let lwe_dimension = params.lwe_dimension.unwrap();
|
||||
let lwe_modular_std_dev = params.lwe_modular_std_dev.unwrap();
|
||||
let glwe_dimension = params.glwe_dimension.unwrap();
|
||||
let polynomial_size = params.polynomial_size.unwrap();
|
||||
let ks_decomp_base_log = params.ks_base_log.unwrap();
|
||||
let ks_decomp_level_count = params.ks_level.unwrap();
|
||||
|
||||
let lwe_sk = allocate_and_generate_new_binary_lwe_secret_key(
|
||||
lwe_dimension,
|
||||
&mut secret_generator,
|
||||
);
|
||||
|
||||
let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
&mut secret_generator,
|
||||
);
|
||||
let big_lwe_sk = glwe_sk.into_lwe_secret_key();
|
||||
let ksk_big_to_small = allocate_and_generate_new_lwe_keyswitch_key(
|
||||
&big_lwe_sk,
|
||||
&lwe_sk,
|
||||
ks_decomp_base_log,
|
||||
ks_decomp_level_count,
|
||||
lwe_modular_std_dev,
|
||||
CiphertextModulus::new_native(),
|
||||
&mut encryption_generator,
|
||||
);
|
||||
let ksk_big_to_small_gpu =
|
||||
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&ksk_big_to_small, &stream);
|
||||
|
||||
let ct = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
&big_lwe_sk,
|
||||
Plaintext(Scalar::ONE),
|
||||
lwe_modular_std_dev,
|
||||
CiphertextModulus::new_native(),
|
||||
&mut encryption_generator,
|
||||
);
|
||||
let mut ct_gpu = CudaLweCiphertextList::from_lwe_ciphertext(&ct, &stream);
|
||||
|
||||
let output_ct = LweCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
lwe_sk.lwe_dimension().to_lwe_size(),
|
||||
CiphertextModulus::new_native(),
|
||||
);
|
||||
let mut output_ct_gpu = CudaLweCiphertextList::from_lwe_ciphertext(&output_ct, &stream);
|
||||
|
||||
let h_indexes = &[Scalar::ZERO];
|
||||
let mut d_input_indexes = unsafe { stream.malloc_async::<Scalar>(1u32) };
|
||||
let mut d_output_indexes = unsafe { stream.malloc_async::<Scalar>(1u32) };
|
||||
unsafe {
|
||||
stream.copy_to_gpu_async(&mut d_input_indexes, h_indexes.as_ref());
|
||||
stream.copy_to_gpu_async(&mut d_output_indexes, h_indexes.as_ref());
|
||||
}
|
||||
stream.synchronize();
|
||||
|
||||
let id = format!("{bench_name}_{name}");
|
||||
{
|
||||
bench_group.bench_function(&id, |b| {
|
||||
b.iter(|| {
|
||||
cuda_keyswitch_lwe_ciphertext(
|
||||
&ksk_big_to_small_gpu,
|
||||
&ct_gpu,
|
||||
&mut output_ct_gpu,
|
||||
&d_input_indexes,
|
||||
&d_output_indexes,
|
||||
&stream,
|
||||
);
|
||||
black_box(&mut ct_gpu);
|
||||
})
|
||||
});
|
||||
}
|
||||
let bit_size = (params.message_modulus.unwrap_or(2) as u32).ilog2();
|
||||
write_to_json(
|
||||
&id,
|
||||
*params,
|
||||
name,
|
||||
"ks",
|
||||
&OperatorType::Atomic,
|
||||
bit_size,
|
||||
vec![bit_size],
|
||||
);
|
||||
}
|
||||
}
|
||||
criterion_group!(
|
||||
name = cuda_keyswitch_group;
|
||||
config = Criterion::default().sample_size(2000);
|
||||
targets = cuda_keyswitch::<u64>, cuda_keyswitch::<u32>
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
use cuda::cuda_keyswitch_group;
|
||||
|
||||
criterion_group!(
|
||||
name = keyswitch_group;
|
||||
config = Criterion::default().sample_size(2000);
|
||||
targets = keyswitch::<u64>, keyswitch::<u32>
|
||||
);
|
||||
#[cfg(not(feature = "gpu"))]
|
||||
criterion_main!(keyswitch_group);
|
||||
#[cfg(feature = "gpu")]
|
||||
criterion_main!(cuda_keyswitch_group);
|
||||
|
||||
@@ -630,7 +630,7 @@ mod cuda {
|
||||
unsafe {
|
||||
stream.copy_to_gpu_async(&mut d_input_indexes, h_indexes.as_ref());
|
||||
stream.copy_to_gpu_async(&mut d_output_indexes, h_indexes.as_ref());
|
||||
stream.copy_to_gpu_async(&mut d_input_indexes, h_indexes.as_ref());
|
||||
stream.copy_to_gpu_async(&mut d_lut_indexes, h_indexes.as_ref());
|
||||
}
|
||||
stream.synchronize();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user