mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
fix(gpu): fix compression throughput benchmark
This commit is contained in:
@@ -159,6 +159,7 @@ fn cpu_glwe_packing(c: &mut Criterion) {
|
||||
mod cuda {
|
||||
use super::*;
|
||||
use benchmark::utilities::cuda_integer_utils::cuda_local_streams;
|
||||
use itertools::Itertools;
|
||||
use std::cmp::max;
|
||||
use tfhe::core_crypto::gpu::CudaStreams;
|
||||
use tfhe::integer::gpu::ciphertext::compressed_ciphertext_list::CudaCompressedCiphertextListBuilder;
|
||||
@@ -203,18 +204,20 @@ mod cuda {
|
||||
let (radix_cks, _) = gen_keys_radix_gpu(param, num_blocks, &stream);
|
||||
let (compressed_compression_key, compressed_decompression_key) =
|
||||
radix_cks.new_compressed_compression_decompression_keys(&private_compression_key);
|
||||
let cuda_compression_key = compressed_compression_key.decompress_to_cuda(&stream);
|
||||
let cuda_decompression_key = compressed_decompression_key.decompress_to_cuda(
|
||||
radix_cks.parameters().glwe_dimension(),
|
||||
radix_cks.parameters().polynomial_size(),
|
||||
radix_cks.parameters().message_modulus(),
|
||||
radix_cks.parameters().carry_modulus(),
|
||||
radix_cks.parameters().ciphertext_modulus(),
|
||||
&stream,
|
||||
);
|
||||
|
||||
match get_bench_type() {
|
||||
BenchmarkType::Latency => {
|
||||
let cuda_compression_key =
|
||||
compressed_compression_key.decompress_to_cuda(&stream);
|
||||
let cuda_decompression_key = compressed_decompression_key.decompress_to_cuda(
|
||||
radix_cks.parameters().glwe_dimension(),
|
||||
radix_cks.parameters().polynomial_size(),
|
||||
radix_cks.parameters().message_modulus(),
|
||||
radix_cks.parameters().carry_modulus(),
|
||||
radix_cks.parameters().ciphertext_modulus(),
|
||||
&stream,
|
||||
);
|
||||
|
||||
// Encrypt
|
||||
let ct = cks.encrypt_radix(0_u32, num_blocks);
|
||||
let d_ct = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &stream);
|
||||
@@ -268,59 +271,84 @@ mod cuda {
|
||||
bench_group.throughput(Throughput::Elements(elements));
|
||||
|
||||
// Encrypt
|
||||
let ct = cks.encrypt_radix(0_u32, num_blocks);
|
||||
let d_ct = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &stream);
|
||||
let local_streams = cuda_local_streams(num_block, elements as usize);
|
||||
|
||||
let cuda_compression_key_vec = local_streams
|
||||
.iter()
|
||||
.map(|local_stream| {
|
||||
compressed_compression_key.decompress_to_cuda(local_stream)
|
||||
})
|
||||
.collect_vec();
|
||||
let cuda_decompression_key_vec = local_streams
|
||||
.iter()
|
||||
.map(|local_stream| {
|
||||
compressed_decompression_key.decompress_to_cuda(
|
||||
radix_cks.parameters().glwe_dimension(),
|
||||
radix_cks.parameters().polynomial_size(),
|
||||
radix_cks.parameters().message_modulus(),
|
||||
radix_cks.parameters().carry_modulus(),
|
||||
radix_cks.parameters().ciphertext_modulus(),
|
||||
local_stream,
|
||||
)
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
// Benchmark
|
||||
let mut builder = CudaCompressedCiphertextListBuilder::new();
|
||||
|
||||
builder.push(d_ct, &stream);
|
||||
|
||||
let builders = (0..elements)
|
||||
.map(|_| {
|
||||
.map(|i| {
|
||||
let ct = cks.encrypt_radix(0_u32, num_blocks);
|
||||
let d_ct =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &stream);
|
||||
let local_stream = &local_streams[i as usize % local_streams.len()];
|
||||
let d_ct = CudaUnsignedRadixCiphertext::from_radix_ciphertext(
|
||||
&ct,
|
||||
local_stream,
|
||||
);
|
||||
let mut builder = CudaCompressedCiphertextListBuilder::new();
|
||||
builder.push(d_ct, &stream);
|
||||
builder.push(d_ct, local_stream);
|
||||
|
||||
builder
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let local_streams = cuda_local_streams(num_block, elements as usize);
|
||||
|
||||
bench_id_pack = format!("{bench_name}::throughput::pack_u{bit_size}");
|
||||
bench_group.bench_function(&bench_id_pack, |b| {
|
||||
b.iter(|| {
|
||||
builders.par_iter().zip(local_streams.par_iter()).for_each(
|
||||
|(builder, local_stream)| {
|
||||
builder.build(&cuda_compression_key, local_stream);
|
||||
},
|
||||
)
|
||||
builders.par_iter().enumerate().for_each(|(i, builder)| {
|
||||
let local_stream = &local_streams[i % local_streams.len()];
|
||||
let cuda_compression_key =
|
||||
&cuda_compression_key_vec[i % local_streams.len()];
|
||||
|
||||
builder.build(cuda_compression_key, local_stream);
|
||||
})
|
||||
})
|
||||
});
|
||||
|
||||
let compressed = builders
|
||||
.iter()
|
||||
.map(|builder| builder.build(&cuda_compression_key, &stream))
|
||||
.enumerate()
|
||||
.map(|(i, builder)| {
|
||||
let local_stream = &local_streams[i % local_streams.len()];
|
||||
let cuda_compression_key =
|
||||
&cuda_compression_key_vec[i % local_streams.len()];
|
||||
builder.build(cuda_compression_key, local_stream)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
bench_id_unpack = format!("{bench_name}::throughput::unpack_u{bit_size}");
|
||||
bench_group.bench_function(&bench_id_unpack, |b| {
|
||||
b.iter(|| {
|
||||
compressed
|
||||
.par_iter()
|
||||
.zip(local_streams.par_iter())
|
||||
.for_each(|(comp, local_stream)| {
|
||||
comp.get::<CudaUnsignedRadixCiphertext>(
|
||||
0,
|
||||
&cuda_decompression_key,
|
||||
local_stream,
|
||||
)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
})
|
||||
compressed.par_iter().enumerate().for_each(|(i, comp)| {
|
||||
let local_stream = &local_streams[i % local_streams.len()];
|
||||
let cuda_decompression_key =
|
||||
&cuda_decompression_key_vec[i % local_streams.len()];
|
||||
|
||||
comp.get::<CudaUnsignedRadixCiphertext>(
|
||||
0,
|
||||
cuda_decompression_key,
|
||||
local_stream,
|
||||
)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
})
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user