mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-07 22:04:10 -05:00
chore(gpu): bench KS latency batches
This commit is contained in:
committed by
Andrei Stoian
parent
d6a0a366b9
commit
e2063c8ef4
@@ -386,7 +386,7 @@ mod cuda {
|
||||
.keyswitch_key(ksk_big_to_small)
|
||||
.build();
|
||||
|
||||
let bench_id;
|
||||
let mut bench_id;
|
||||
|
||||
match get_bench_type() {
|
||||
BenchmarkType::Latency => {
|
||||
@@ -423,120 +423,163 @@ mod cuda {
|
||||
&mut output_ct_gpu,
|
||||
&cuda_indexes.d_input,
|
||||
&cuda_indexes.d_output,
|
||||
true,
|
||||
&streams,
|
||||
false,
|
||||
);
|
||||
|
||||
black_box(&mut ct_gpu);
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
let bit_size = (params.message_modulus.unwrap_or(2) as u32).ilog2();
|
||||
write_to_json(
|
||||
&bench_id,
|
||||
*params,
|
||||
name,
|
||||
"ks",
|
||||
&OperatorType::Atomic,
|
||||
bit_size,
|
||||
vec![bit_size],
|
||||
);
|
||||
}
|
||||
BenchmarkType::Throughput => {
|
||||
let gpu_keys_vec = cuda_local_keys_core(&cpu_keys, None);
|
||||
let gpu_count = get_number_of_gpus() as usize;
|
||||
|
||||
bench_id = format!("{bench_name}::throughput::{name}");
|
||||
let blocks: usize = 1;
|
||||
let elements = throughput_num_threads(blocks, 1);
|
||||
let elements_per_stream = elements as usize / gpu_count;
|
||||
bench_group.throughput(Throughput::Elements(elements));
|
||||
bench_group.sample_size(50);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
let setup_encrypted_values = || {
|
||||
let local_streams = cuda_local_streams_core();
|
||||
|
||||
let plaintext_list = PlaintextList::new(
|
||||
Scalar::ZERO,
|
||||
PlaintextCount(elements_per_stream),
|
||||
for uses_gemm_ks in [false, true] {
|
||||
for uses_simple_indices in [false, true] {
|
||||
let indices_str = if uses_simple_indices {
|
||||
"simple"
|
||||
} else {
|
||||
"complex"
|
||||
};
|
||||
let gemm_str = if uses_gemm_ks { "gemm" } else { "classical" };
|
||||
bench_id = format!(
|
||||
"{bench_name}::throughput::{gemm_str}::{indices_str}_indices::{name}",
|
||||
);
|
||||
|
||||
let input_cts = (0..gpu_count)
|
||||
.map(|i| {
|
||||
let mut input_ct_list = LweCiphertextList::new(
|
||||
let blocks: usize = 256;
|
||||
let elements = gpu_count * blocks;
|
||||
let elements_per_stream = elements / gpu_count;
|
||||
bench_group.throughput(Throughput::Elements(elements as u64));
|
||||
bench_group.sample_size(50);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
let setup_encrypted_values = || {
|
||||
let local_streams = cuda_local_streams_core();
|
||||
|
||||
let plaintext_list = PlaintextList::new(
|
||||
Scalar::ZERO,
|
||||
big_lwe_sk.lwe_dimension().to_lwe_size(),
|
||||
LweCiphertextCount(elements_per_stream),
|
||||
params.ciphertext_modulus.unwrap(),
|
||||
PlaintextCount(elements_per_stream),
|
||||
);
|
||||
encrypt_lwe_ciphertext_list(
|
||||
&big_lwe_sk,
|
||||
&mut input_ct_list,
|
||||
&plaintext_list,
|
||||
params.lwe_noise_distribution.unwrap(),
|
||||
&mut encryption_generator,
|
||||
);
|
||||
let input_ks_list = LweCiphertextList::from_container(
|
||||
input_ct_list.into_container(),
|
||||
big_lwe_sk.lwe_dimension().to_lwe_size(),
|
||||
params.ciphertext_modulus.unwrap(),
|
||||
);
|
||||
CudaLweCiphertextList::from_lwe_ciphertext_list(
|
||||
&input_ks_list,
|
||||
&local_streams[i],
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let output_cts = (0..gpu_count)
|
||||
.map(|i| {
|
||||
let output_ct_list = LweCiphertextList::new(
|
||||
Scalar::ZERO,
|
||||
lwe_sk.lwe_dimension().to_lwe_size(),
|
||||
LweCiphertextCount(elements_per_stream),
|
||||
params.ciphertext_modulus.unwrap(),
|
||||
);
|
||||
CudaLweCiphertextList::from_lwe_ciphertext_list(
|
||||
&output_ct_list,
|
||||
&local_streams[i],
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let input_cts = (0..gpu_count)
|
||||
.map(|i| {
|
||||
let mut input_ct_list = LweCiphertextList::new(
|
||||
Scalar::ZERO,
|
||||
big_lwe_sk.lwe_dimension().to_lwe_size(),
|
||||
LweCiphertextCount(elements_per_stream),
|
||||
params.ciphertext_modulus.unwrap(),
|
||||
);
|
||||
encrypt_lwe_ciphertext_list(
|
||||
&big_lwe_sk,
|
||||
&mut input_ct_list,
|
||||
&plaintext_list,
|
||||
params.lwe_noise_distribution.unwrap(),
|
||||
&mut encryption_generator,
|
||||
);
|
||||
let input_ks_list = LweCiphertextList::from_container(
|
||||
input_ct_list.into_container(),
|
||||
big_lwe_sk.lwe_dimension().to_lwe_size(),
|
||||
params.ciphertext_modulus.unwrap(),
|
||||
);
|
||||
CudaLweCiphertextList::from_lwe_ciphertext_list(
|
||||
&input_ks_list,
|
||||
&local_streams[i],
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let h_indexes = (0..(elements / gpu_count as u64))
|
||||
.map(CastFrom::cast_from)
|
||||
.collect::<Vec<_>>();
|
||||
let cuda_indexes_vec = (0..gpu_count)
|
||||
.map(|i| CudaIndexes::new(&h_indexes, &local_streams[i], 0))
|
||||
.collect::<Vec<_>>();
|
||||
local_streams.iter().for_each(|stream| stream.synchronize());
|
||||
let output_cts = (0..gpu_count)
|
||||
.map(|i| {
|
||||
let output_ct_list = LweCiphertextList::new(
|
||||
Scalar::ZERO,
|
||||
lwe_sk.lwe_dimension().to_lwe_size(),
|
||||
LweCiphertextCount(elements_per_stream),
|
||||
params.ciphertext_modulus.unwrap(),
|
||||
);
|
||||
CudaLweCiphertextList::from_lwe_ciphertext_list(
|
||||
&output_ct_list,
|
||||
&local_streams[i],
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
(input_cts, output_cts, cuda_indexes_vec, local_streams)
|
||||
};
|
||||
let indexes_range: Vec<u64> = if uses_simple_indices {
|
||||
(0..(elements / gpu_count) as u64).collect()
|
||||
} else {
|
||||
(0..(elements / gpu_count) as u64).rev().collect()
|
||||
};
|
||||
let h_indexes = indexes_range
|
||||
.iter()
|
||||
.map(|v| CastFrom::cast_from(*v))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
b.iter_batched(
|
||||
setup_encrypted_values,
|
||||
|(input_cts, mut output_cts, cuda_indexes_vec, local_streams)| {
|
||||
(0..gpu_count)
|
||||
.into_par_iter()
|
||||
.zip(input_cts.par_iter())
|
||||
.zip(output_cts.par_iter_mut())
|
||||
.zip(local_streams.par_iter())
|
||||
.for_each(|(((i, input_ct), output_ct), local_stream)| {
|
||||
cuda_keyswitch_lwe_ciphertext(
|
||||
gpu_keys_vec[i].ksk.as_ref().unwrap(),
|
||||
input_ct,
|
||||
output_ct,
|
||||
&cuda_indexes_vec[i].d_input,
|
||||
&cuda_indexes_vec[i].d_output,
|
||||
local_stream,
|
||||
);
|
||||
})
|
||||
},
|
||||
criterion::BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
let cuda_indexes_vec = (0..gpu_count)
|
||||
.map(|i| CudaIndexes::new(&h_indexes, &local_streams[i], 0))
|
||||
.collect::<Vec<_>>();
|
||||
local_streams.iter().for_each(|stream| stream.synchronize());
|
||||
|
||||
(input_cts, output_cts, cuda_indexes_vec, local_streams)
|
||||
};
|
||||
|
||||
b.iter_batched(
|
||||
setup_encrypted_values,
|
||||
|(
|
||||
input_cts,
|
||||
mut output_cts,
|
||||
cuda_indexes_vec,
|
||||
local_streams,
|
||||
)| {
|
||||
(0..gpu_count)
|
||||
.into_par_iter()
|
||||
.zip(input_cts.par_iter())
|
||||
.zip(output_cts.par_iter_mut())
|
||||
.zip(local_streams.par_iter())
|
||||
.for_each(
|
||||
|(((i, input_ct), output_ct), local_stream)| {
|
||||
cuda_keyswitch_lwe_ciphertext(
|
||||
gpu_keys_vec[i].ksk.as_ref().unwrap(),
|
||||
input_ct,
|
||||
output_ct,
|
||||
&cuda_indexes_vec[i].d_input,
|
||||
&cuda_indexes_vec[i].d_output,
|
||||
uses_simple_indices,
|
||||
local_stream,
|
||||
uses_gemm_ks,
|
||||
);
|
||||
},
|
||||
)
|
||||
},
|
||||
criterion::BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
|
||||
let bit_size = (params.message_modulus.unwrap_or(2) as u32).ilog2();
|
||||
write_to_json(
|
||||
&bench_id,
|
||||
*params,
|
||||
name,
|
||||
"ks",
|
||||
&OperatorType::Atomic,
|
||||
bit_size,
|
||||
vec![bit_size],
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let bit_size = (params.message_modulus.unwrap_or(2) as u32).ilog2();
|
||||
write_to_json(
|
||||
&bench_id,
|
||||
*params,
|
||||
name,
|
||||
"ks",
|
||||
&OperatorType::Atomic,
|
||||
bit_size,
|
||||
vec![bit_size],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -630,7 +630,9 @@ mod cuda {
|
||||
&mut output_ks_ct_gpu,
|
||||
&cuda_indexes.d_input,
|
||||
&cuda_indexes.d_output,
|
||||
true,
|
||||
&streams,
|
||||
false,
|
||||
);
|
||||
cuda_programmable_bootstrap_lwe_ciphertext(
|
||||
&output_ks_ct_gpu,
|
||||
@@ -782,7 +784,9 @@ mod cuda {
|
||||
output_ks_ct,
|
||||
&cuda_indexes_vec[i].d_input,
|
||||
&cuda_indexes_vec[i].d_output,
|
||||
true,
|
||||
local_stream,
|
||||
false,
|
||||
);
|
||||
cuda_programmable_bootstrap_lwe_ciphertext(
|
||||
output_ks_ct,
|
||||
@@ -937,7 +941,9 @@ mod cuda {
|
||||
&mut output_ks_ct_gpu,
|
||||
&cuda_indexes.d_input,
|
||||
&cuda_indexes.d_output,
|
||||
true,
|
||||
&streams,
|
||||
false,
|
||||
);
|
||||
cuda_multi_bit_programmable_bootstrap_lwe_ciphertext(
|
||||
&output_ks_ct_gpu,
|
||||
@@ -1088,7 +1094,9 @@ mod cuda {
|
||||
output_ks_ct,
|
||||
&cuda_indexes_vec[i].d_input,
|
||||
&cuda_indexes_vec[i].d_output,
|
||||
true,
|
||||
local_stream,
|
||||
false,
|
||||
);
|
||||
cuda_multi_bit_programmable_bootstrap_lwe_ciphertext(
|
||||
output_ks_ct,
|
||||
|
||||
Reference in New Issue
Block a user