mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
fix(gpu): Fix expand bench on multi-gpus
This commit is contained in:
@@ -421,8 +421,6 @@ mod cuda {
|
||||
.sample_size(15)
|
||||
.measurement_time(std::time::Duration::from_secs(60));
|
||||
|
||||
let streams = CudaStreams::new_multi_gpu();
|
||||
|
||||
File::create(results_file).expect("create results file failed");
|
||||
let mut file = OpenOptions::new()
|
||||
.append(true)
|
||||
@@ -439,17 +437,10 @@ mod cuda {
|
||||
let cks = ClientKey::new(param_fhe);
|
||||
let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
|
||||
let sk = compressed_server_key.decompress();
|
||||
let gpu_sks = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);
|
||||
|
||||
let compact_private_key = CompactPrivateKey::new(param_pke);
|
||||
let pk = CompactPublicKey::new(&compact_private_key);
|
||||
let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), param_ksk);
|
||||
let d_ksk_material =
|
||||
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
|
||||
let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(
|
||||
&d_ksk_material,
|
||||
&gpu_sks,
|
||||
);
|
||||
|
||||
// We have a use case with 320 bits of metadata
|
||||
let mut metadata = [0u8; (320 / u8::BITS) as usize];
|
||||
@@ -509,6 +500,18 @@ mod cuda {
|
||||
|
||||
match get_bench_type() {
|
||||
BenchmarkType::Latency => {
|
||||
let streams = CudaStreams::new_multi_gpu();
|
||||
let gpu_sks = CudaServerKey::decompress_from_cpu(
|
||||
&compressed_server_key,
|
||||
&streams,
|
||||
);
|
||||
let d_ksk_material =
|
||||
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
|
||||
let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(
|
||||
&d_ksk_material,
|
||||
&gpu_sks,
|
||||
);
|
||||
|
||||
bench_id_verify = format!(
|
||||
"{bench_name}::{param_name}_{bits}_bits_packed_{zk_load}_ZK{zk_vers:?}"
|
||||
);
|
||||
@@ -599,9 +602,7 @@ mod cuda {
|
||||
});
|
||||
}
|
||||
BenchmarkType::Throughput => {
|
||||
let gpu_count = get_number_of_gpus() as usize;
|
||||
|
||||
let elements = zk_throughput_num_elements();
|
||||
let elements = 100 * get_number_of_gpus() as u64; // This value, found empirically, ensures saturation of 8XH100 SXM5
|
||||
bench_group.throughput(Throughput::Elements(elements));
|
||||
|
||||
bench_id_verify = format!(
|
||||
@@ -636,8 +637,6 @@ mod cuda {
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(d_ksk_material_vec.len(), gpu_count);
|
||||
|
||||
bench_group.bench_function(&bench_id_verify, |b| {
|
||||
b.iter(|| {
|
||||
cts.par_iter().for_each(|ct1| {
|
||||
@@ -648,23 +647,25 @@ mod cuda {
|
||||
|
||||
bench_group.bench_function(&bench_id_expand_without_verify, |b| {
|
||||
let setup_encrypted_values = || {
|
||||
let local_streams = cuda_local_streams(num_block, elements as usize);
|
||||
|
||||
let gpu_cts = cts.iter().enumerate().map(|(i, ct)| {
|
||||
let local_stream = &local_streams[i % local_streams.len()];
|
||||
CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
|
||||
ct, &local_streams[i],
|
||||
ct, local_stream,
|
||||
)
|
||||
}).collect_vec();
|
||||
|
||||
(gpu_cts, local_streams)
|
||||
gpu_cts
|
||||
};
|
||||
|
||||
b.iter_batched(setup_encrypted_values,
|
||||
|(gpu_cts, local_streams)| {
|
||||
gpu_cts.par_iter().zip(local_streams.par_iter()).enumerate().for_each
|
||||
(|(i, (gpu_ct, local_stream))| {
|
||||
|gpu_cts| {
|
||||
gpu_cts.par_iter().enumerate().for_each
|
||||
(|(i, gpu_ct)| {
|
||||
let local_stream = &local_streams[i % local_streams.len()];
|
||||
|
||||
let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, local_stream);
|
||||
let d_ksk =
|
||||
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % gpu_count], &gpu_sks);
|
||||
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % local_streams.len()], &gpu_sk);
|
||||
|
||||
gpu_ct
|
||||
.expand_without_verification(&d_ksk, local_stream)
|
||||
@@ -675,21 +676,24 @@ mod cuda {
|
||||
|
||||
bench_group.bench_function(&bench_id_verify_and_expand, |b| {
|
||||
let setup_encrypted_values = || {
|
||||
let local_streams = cuda_local_streams(num_block, elements as usize);
|
||||
|
||||
let gpu_cts = cts.iter().enumerate().map(|(i, ct)| {
|
||||
CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
|
||||
ct, &local_streams[i],
|
||||
ct, &local_streams[i% local_streams.len()],
|
||||
)
|
||||
}).collect_vec();
|
||||
|
||||
(gpu_cts, local_streams)
|
||||
gpu_cts
|
||||
};
|
||||
|
||||
b.iter_batched(setup_encrypted_values,
|
||||
|(gpu_cts, local_streams)| {
|
||||
gpu_cts.par_iter().zip(local_streams.par_iter()).for_each
|
||||
(|(gpu_ct, local_stream)| {
|
||||
|gpu_cts| {
|
||||
gpu_cts.par_iter().enumerate().for_each
|
||||
(|(i, gpu_ct)| {
|
||||
let local_stream = &local_streams[i % local_streams.len()];
|
||||
let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, local_stream);
|
||||
let d_ksk =
|
||||
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % local_streams.len()], &gpu_sk);
|
||||
|
||||
gpu_ct
|
||||
.verify_and_expand(
|
||||
&crs, &pk, &metadata, &d_ksk, local_stream,
|
||||
|
||||
Reference in New Issue
Block a user