mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 22:57:59 -05:00
chore(gpu): skip packing ks test if it needs more ram than available
This commit is contained in:
committed by
Agnès Leroy
parent
55179c52a7
commit
7986e0bf1d
@@ -331,10 +331,13 @@ mod cuda {
|
|||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use tfhe::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
|
use tfhe::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
|
||||||
use tfhe::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
|
use tfhe::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
|
||||||
|
use tfhe::core_crypto::gpu::vec::GpuIndex;
|
||||||
use tfhe::core_crypto::gpu::{
|
use tfhe::core_crypto::gpu::{
|
||||||
cuda_keyswitch_lwe_ciphertext, cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64,
|
check_valid_cuda_malloc, cuda_keyswitch_lwe_ciphertext,
|
||||||
get_number_of_gpus, CudaStreams,
|
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64, get_number_of_gpus,
|
||||||
|
get_packing_keyswitch_list_64_size_on_gpu, CudaStreams,
|
||||||
};
|
};
|
||||||
|
|
||||||
use tfhe::core_crypto::prelude::*;
|
use tfhe::core_crypto::prelude::*;
|
||||||
|
|
||||||
fn cuda_keyswitch<Scalar: UnsignedTorus + CastInto<usize> + CastFrom<u64> + Serialize>(
|
fn cuda_keyswitch<Scalar: UnsignedTorus + CastInto<usize> + CastFrom<u64> + Serialize>(
|
||||||
@@ -588,10 +591,24 @@ mod cuda {
|
|||||||
let cpu_keys: CpuKeys<_> = CpuKeysBuilder::new().packing_keyswitch_key(pksk).build();
|
let cpu_keys: CpuKeys<_> = CpuKeysBuilder::new().packing_keyswitch_key(pksk).build();
|
||||||
|
|
||||||
let bench_id;
|
let bench_id;
|
||||||
|
|
||||||
match get_bench_type() {
|
match get_bench_type() {
|
||||||
BenchmarkType::Latency => {
|
BenchmarkType::Latency => {
|
||||||
let streams = CudaStreams::new_multi_gpu();
|
let streams = CudaStreams::new_multi_gpu();
|
||||||
|
|
||||||
|
let mem_size = get_packing_keyswitch_list_64_size_on_gpu(
|
||||||
|
&streams,
|
||||||
|
lwe_sk.lwe_dimension(),
|
||||||
|
glwe_sk.glwe_dimension(),
|
||||||
|
glwe_sk.polynomial_size(),
|
||||||
|
LweCiphertextCount(glwe_sk.polynomial_size().0),
|
||||||
|
);
|
||||||
|
|
||||||
|
let skip_bench = !check_valid_cuda_malloc(mem_size, GpuIndex::new(0));
|
||||||
|
|
||||||
|
if skip_bench {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
let gpu_keys = CudaLocalKeys::from_cpu_keys(&cpu_keys, None, &streams);
|
let gpu_keys = CudaLocalKeys::from_cpu_keys(&cpu_keys, None, &streams);
|
||||||
|
|
||||||
let mut input_ct_list = LweCiphertextList::new(
|
let mut input_ct_list = LweCiphertextList::new(
|
||||||
@@ -647,9 +664,30 @@ mod cuda {
|
|||||||
let gpu_count = get_number_of_gpus() as usize;
|
let gpu_count = get_number_of_gpus() as usize;
|
||||||
|
|
||||||
bench_id = format!("{bench_name}::throughput::{name}");
|
bench_id = format!("{bench_name}::throughput::{name}");
|
||||||
|
|
||||||
|
let mem_size = get_packing_keyswitch_list_64_size_on_gpu(
|
||||||
|
&CudaStreams::new_single_gpu(GpuIndex::new(0)),
|
||||||
|
lwe_sk.lwe_dimension(),
|
||||||
|
glwe_sk.glwe_dimension(),
|
||||||
|
glwe_sk.polynomial_size(),
|
||||||
|
LweCiphertextCount(glwe_sk.polynomial_size().0),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut skip_test = false;
|
||||||
|
for gpu_index in 0..gpu_count {
|
||||||
|
if !check_valid_cuda_malloc(mem_size, GpuIndex::new(gpu_index as u32)) {
|
||||||
|
skip_test = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if skip_test {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
let blocks: usize = 1;
|
let blocks: usize = 1;
|
||||||
let elements = throughput_num_threads(blocks, 1);
|
let elements = throughput_num_threads(blocks, 1);
|
||||||
let elements_per_stream = elements as usize / gpu_count;
|
let elements_per_stream =
|
||||||
|
std::cmp::min(elements as usize / gpu_count, glwe_sk.polynomial_size().0);
|
||||||
bench_group.throughput(Throughput::Elements(elements));
|
bench_group.throughput(Throughput::Elements(elements));
|
||||||
bench_group.sample_size(50);
|
bench_group.sample_size(50);
|
||||||
bench_group.bench_function(&bench_id, |b| {
|
bench_group.bench_function(&bench_id, |b| {
|
||||||
@@ -666,7 +704,7 @@ mod cuda {
|
|||||||
let mut input_ct_list = LweCiphertextList::new(
|
let mut input_ct_list = LweCiphertextList::new(
|
||||||
Scalar::ZERO,
|
Scalar::ZERO,
|
||||||
lwe_sk.lwe_dimension().to_lwe_size(),
|
lwe_sk.lwe_dimension().to_lwe_size(),
|
||||||
LweCiphertextCount(glwe_sk.polynomial_size().0),
|
LweCiphertextCount(elements_per_stream),
|
||||||
ciphertext_modulus,
|
ciphertext_modulus,
|
||||||
);
|
);
|
||||||
encrypt_lwe_ciphertext_list(
|
encrypt_lwe_ciphertext_list(
|
||||||
@@ -743,8 +781,9 @@ mod cuda {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn cuda_ks_group() {
|
pub fn cuda_ks_group() {
|
||||||
let mut criterion: Criterion<_> =
|
let mut criterion: Criterion<_> = (Criterion::default().sample_size(15))
|
||||||
(Criterion::default().sample_size(2000)).configure_from_args();
|
.measurement_time(std::time::Duration::from_secs(60))
|
||||||
|
.configure_from_args();
|
||||||
cuda_keyswitch(&mut criterion, &benchmark_parameters());
|
cuda_keyswitch(&mut criterion, &benchmark_parameters());
|
||||||
cuda_packing_keyswitch(&mut criterion, &benchmark_parameters());
|
cuda_packing_keyswitch(&mut criterion, &benchmark_parameters());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1062,6 +1062,37 @@ pub fn is_cuda_available() -> bool {
|
|||||||
result == 1u32
|
result == 1u32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_packing_keyswitch_list_64_size_on_gpu(
|
||||||
|
streams: &CudaStreams,
|
||||||
|
input_lwe_dimension: LweDimension,
|
||||||
|
output_glwe_dimension: GlweDimension,
|
||||||
|
output_polynomial_size: PolynomialSize,
|
||||||
|
num_lwes: LweCiphertextCount,
|
||||||
|
) -> u64 {
|
||||||
|
let mut fp_ks_buffer: *mut i8 = std::ptr::null_mut();
|
||||||
|
let size_tracker = unsafe {
|
||||||
|
scratch_packing_keyswitch_lwe_list_to_glwe_64(
|
||||||
|
streams.ptr[0],
|
||||||
|
streams.gpu_indexes[0].get(),
|
||||||
|
std::ptr::addr_of_mut!(fp_ks_buffer),
|
||||||
|
input_lwe_dimension.0 as u32,
|
||||||
|
output_glwe_dimension.0 as u32,
|
||||||
|
output_polynomial_size.0 as u32,
|
||||||
|
num_lwes.0 as u32,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
unsafe {
|
||||||
|
cleanup_packing_keyswitch_lwe_list_to_glwe(
|
||||||
|
streams.ptr[0],
|
||||||
|
streams.gpu_indexes[0].get(),
|
||||||
|
std::ptr::addr_of_mut!(fp_ks_buffer),
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
size_tracker
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|||||||
Reference in New Issue
Block a user