Mirror of https://github.com/zama-ai/tfhe-rs.git, synced 2026-01-09 22:57:59 -05:00
feat(gpu): add support for GPU-accelerated expand on the HL Api
- includes documentation about GPU-accelerated expand on the HL API
- rework CudaKeySwitchingKey
- cloning the key is no longer necessary on the HL API
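In short: instead of building a CudaKeySwitchingKey directly from the client key and the GPU server key, the key switching key is now created once on the CPU, uploaded as CudaKeySwitchingKeyMaterial, and then borrowed to form a CudaKeySwitchingKey, so the key no longer has to be cloned. A minimal sketch of the new flow, assembled from the benchmark changes below (the parameter sets param_fhe, param_pke, param_ksk and the CUDA streams are assumed to be set up as in the rest of the benchmark):

    let cks = ClientKey::new(param_fhe);
    let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
    let sk = compressed_server_key.decompress();
    // GPU server key, decompressed directly from the CPU compressed key
    let gpu_sks = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);

    let compact_private_key = CompactPrivateKey::new(param_pke);
    let pk = CompactPublicKey::new(&compact_private_key);

    // Build the key switching key once on the CPU...
    let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), param_ksk);
    // ...upload it as GPU key material...
    let d_ksk_material = CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
    // ...and borrow the material to get a CudaKeySwitchingKey, no clone required.
    let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material, &gpu_sks);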
@@ -418,11 +418,11 @@ fn cpu_pke_zk_verify(c: &mut Criterion, results_file: &Path) {
 #[cfg(all(feature = "gpu", feature = "zk-pok"))]
 mod cuda {
     use super::*;
-    use benchmark::utilities::{cuda_local_keys, cuda_local_streams};
+    use benchmark::utilities::cuda_local_streams;
     use criterion::BatchSize;
     use itertools::Itertools;
     use tfhe::core_crypto::gpu::{get_number_of_gpus, CudaStreams};
-    use tfhe::integer::gpu::key_switching_key::CudaKeySwitchingKey;
+    use tfhe::integer::gpu::key_switching_key::{CudaKeySwitchingKey, CudaKeySwitchingKeyMaterial};
     use tfhe::integer::gpu::zk::CudaProvenCompactCiphertextList;
     use tfhe::integer::gpu::CudaServerKey;
     use tfhe::integer::CompressedServerKey;
@@ -451,14 +451,17 @@ mod cuda {
         let param_name = param_name.as_str();
         let cks = ClientKey::new(param_fhe);
         let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
+        let sk = compressed_server_key.decompress();
         let gpu_sks = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);
+
         let compact_private_key = CompactPrivateKey::new(param_pke);
         let pk = CompactPublicKey::new(&compact_private_key);
-        let d_ksk = CudaKeySwitchingKey::new(
-            (&compact_private_key, None),
-            (&cks, &gpu_sks),
-            param_ksk,
-            &streams,
+        let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), param_ksk);
+        let d_ksk_material =
+            CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
+        let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(
+            &d_ksk_material,
+            &gpu_sks,
         );

         // We have a use case with 320 bits of metadata
@@ -609,7 +612,6 @@ mod cuda {
                 });
             }
             BenchmarkType::Throughput => {
-                let gpu_sks_vec = cuda_local_keys(&cks);
                 let gpu_count = get_number_of_gpus() as usize;

                 let elements = zk_throughput_num_elements();
@@ -637,20 +639,17 @@ mod cuda {
                     .collect::<Vec<_>>();

                 let local_streams = cuda_local_streams(num_block, elements as usize);
-                let d_ksk_vec = gpu_sks_vec
+                let d_ksk_material_vec = local_streams
                     .par_iter()
-                    .zip(local_streams.par_iter())
-                    .map(|(gpu_sks, local_stream)| {
-                        CudaKeySwitchingKey::new(
-                            (&compact_private_key, None),
-                            (&cks, gpu_sks),
-                            param_ksk,
-                            local_stream,
+                    .map(|local_stream| {
+                        CudaKeySwitchingKeyMaterial::from_key_switching_key(
+                            &ksk,
+                            local_stream,
                         )
                     })
                     .collect::<Vec<_>>();

-                assert_eq!(d_ksk_vec.len(), gpu_count);
+                assert_eq!(d_ksk_material_vec.len(), gpu_count);

                 bench_group.bench_function(&bench_id_verify, |b| {
                     b.iter(|| {
@@ -673,14 +672,16 @@ mod cuda {
                             (gpu_cts, local_streams)
                         };

-                        b.iter_batched(setup_encrypted_values, |(gpu_cts, local_streams)| {
-                            gpu_cts.par_iter()
-                                .zip(local_streams.par_iter())
-                                .enumerate()
-                                .for_each(|(i, (gpu_ct, local_stream))| {
-                                    gpu_ct
-                                        .expand_without_verification(&d_ksk_vec[i % gpu_count], local_stream)
-                                        .unwrap();
+                        b.iter_batched(setup_encrypted_values,
+                                       |(gpu_cts, local_streams)| {
+                            gpu_cts.par_iter().zip(local_streams.par_iter()).enumerate().for_each
+                            (|(i, (gpu_ct, local_stream))| {
+                                let d_ksk =
+                                    CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % gpu_count], &gpu_sks);
+
+                                gpu_ct
+                                    .expand_without_verification(&d_ksk, local_stream)
+                                    .unwrap();
                             });
                         }, BatchSize::SmallInput);
                     });
@@ -698,16 +699,15 @@ mod cuda {
                             (gpu_cts, local_streams)
                         };

-                        b.iter_batched(setup_encrypted_values, |(gpu_cts, local_streams)| {
-                            gpu_cts
-                                .par_iter()
-                                .zip(local_streams.par_iter())
-                                .for_each(|(gpu_ct, local_stream)| {
-                                    gpu_ct
-                                        .verify_and_expand(
-                                            &crs, &pk, &metadata, &d_ksk, local_stream
-                                        )
-                                        .unwrap();
+                        b.iter_batched(setup_encrypted_values,
+                                       |(gpu_cts, local_streams)| {
+                            gpu_cts.par_iter().zip(local_streams.par_iter()).for_each
+                            (|(gpu_ct, local_stream)| {
+                                gpu_ct
+                                    .verify_and_expand(
+                                        &crs, &pk, &metadata, &d_ksk, local_stream,
+                                    )
+                                    .unwrap();
                             });
                         }, BatchSize::SmallInput);
                     });