fix(gpu): Fix expand bench on multi-gpus

This commit is contained in:
Pedro Alves
2025-07-07 12:46:08 -03:00
committed by Agnès Leroy
parent 776f08b534
commit 9960f5e8b6
4 changed files with 172 additions and 146 deletions

View File

@@ -421,8 +421,6 @@ mod cuda {
.sample_size(15)
.measurement_time(std::time::Duration::from_secs(60));
let streams = CudaStreams::new_multi_gpu();
File::create(results_file).expect("create results file failed");
let mut file = OpenOptions::new()
.append(true)
@@ -439,17 +437,10 @@ mod cuda {
let cks = ClientKey::new(param_fhe);
let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
let sk = compressed_server_key.decompress();
let gpu_sks = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);
let compact_private_key = CompactPrivateKey::new(param_pke);
let pk = CompactPublicKey::new(&compact_private_key);
let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), param_ksk);
let d_ksk_material =
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(
&d_ksk_material,
&gpu_sks,
);
// We have a use case with 320 bits of metadata
let mut metadata = [0u8; (320 / u8::BITS) as usize];
@@ -509,6 +500,18 @@ mod cuda {
match get_bench_type() {
BenchmarkType::Latency => {
let streams = CudaStreams::new_multi_gpu();
let gpu_sks = CudaServerKey::decompress_from_cpu(
&compressed_server_key,
&streams,
);
let d_ksk_material =
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(
&d_ksk_material,
&gpu_sks,
);
bench_id_verify = format!(
"{bench_name}::{param_name}_{bits}_bits_packed_{zk_load}_ZK{zk_vers:?}"
);
@@ -599,9 +602,7 @@ mod cuda {
});
}
BenchmarkType::Throughput => {
let gpu_count = get_number_of_gpus() as usize;
let elements = zk_throughput_num_elements();
let elements = 100 * get_number_of_gpus() as u64; // This value, found empirically, ensures saturation of 8XH100 SXM5
bench_group.throughput(Throughput::Elements(elements));
bench_id_verify = format!(
@@ -636,8 +637,6 @@ mod cuda {
})
.collect::<Vec<_>>();
assert_eq!(d_ksk_material_vec.len(), gpu_count);
bench_group.bench_function(&bench_id_verify, |b| {
b.iter(|| {
cts.par_iter().for_each(|ct1| {
@@ -648,23 +647,25 @@ mod cuda {
bench_group.bench_function(&bench_id_expand_without_verify, |b| {
let setup_encrypted_values = || {
let local_streams = cuda_local_streams(num_block, elements as usize);
let gpu_cts = cts.iter().enumerate().map(|(i, ct)| {
let local_stream = &local_streams[i % local_streams.len()];
CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
ct, &local_streams[i],
ct, local_stream,
)
}).collect_vec();
(gpu_cts, local_streams)
gpu_cts
};
b.iter_batched(setup_encrypted_values,
|(gpu_cts, local_streams)| {
gpu_cts.par_iter().zip(local_streams.par_iter()).enumerate().for_each
(|(i, (gpu_ct, local_stream))| {
|gpu_cts| {
gpu_cts.par_iter().enumerate().for_each
(|(i, gpu_ct)| {
let local_stream = &local_streams[i % local_streams.len()];
let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, local_stream);
let d_ksk =
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % gpu_count], &gpu_sks);
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % local_streams.len()], &gpu_sk);
gpu_ct
.expand_without_verification(&d_ksk, local_stream)
@@ -675,21 +676,24 @@ mod cuda {
bench_group.bench_function(&bench_id_verify_and_expand, |b| {
let setup_encrypted_values = || {
let local_streams = cuda_local_streams(num_block, elements as usize);
let gpu_cts = cts.iter().enumerate().map(|(i, ct)| {
CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
ct, &local_streams[i],
ct, &local_streams[i% local_streams.len()],
)
}).collect_vec();
(gpu_cts, local_streams)
gpu_cts
};
b.iter_batched(setup_encrypted_values,
|(gpu_cts, local_streams)| {
gpu_cts.par_iter().zip(local_streams.par_iter()).for_each
(|(gpu_ct, local_stream)| {
|gpu_cts| {
gpu_cts.par_iter().enumerate().for_each
(|(i, gpu_ct)| {
let local_stream = &local_streams[i % local_streams.len()];
let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, local_stream);
let d_ksk =
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % local_streams.len()], &gpu_sk);
gpu_ct
.verify_and_expand(
&crs, &pk, &metadata, &d_ksk, local_stream,