WIP: put back zk proof bench and run only 256 bits proof

This commit is contained in:
David Testé
2025-11-17 15:03:01 +01:00
parent 41fabf79e4
commit f63fd7c0e8
2 changed files with 139 additions and 134 deletions

View File

@@ -38,7 +38,8 @@ fn default_proof_config() -> Vec<ProofConfig> {
vec![
// ProofConfig::new(64usize, &[64usize]),
// ProofConfig::new(640, &[640]),
ProofConfig::new(2048, &[2048, 4 * 64usize]),
// ProofConfig::new(2048, &[2048, 4 * 64usize]),
ProofConfig::new(2048, &[4 * 64usize]),
// ProofConfig::new(4096, &[4096]),
]
}
@@ -67,135 +68,139 @@ fn zk_throughput_num_elements() -> u64 {
}
}
// fn cpu_pke_zk_proof(c: &mut Criterion) {
// let bench_name = "zk::pke_zk_proof";
// let mut bench_group = c.benchmark_group(bench_name);
// bench_group
// .sample_size(15)
// .measurement_time(std::time::Duration::from_secs(60));
//
// for (param_pke, _param_casting, param_fhe) in [
// (
// BENCH_PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
// BENCH_PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
// BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
// ),
// (
// BENCH_PARAM_PKE_TO_SMALL_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128_ZKV1,
// BENCH_PARAM_KEYSWITCH_PKE_TO_SMALL_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128_ZKV1,
// BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
// ),
// ] {
// let param_name = param_fhe.name();
// let param_name = param_name.as_str();
// let cks = ClientKey::new(param_fhe);
// let sks = ServerKey::new_radix_server_key(&cks);
// let compact_private_key = CompactPrivateKey::new(param_pke);
// let pk = CompactPublicKey::new(&compact_private_key);
// // Kept for consistency
// let _casting_key =
// KeySwitchingKey::new((&compact_private_key, None), (&cks, &sks), _param_casting);
//
// // We have a use case with 320 bits of metadata
// let mut metadata = [0u8; (320 / u8::BITS) as usize];
// let mut rng = rand::thread_rng();
// metadata.fill_with(|| rng.gen());
//
// let zk_vers = param_pke.zk_scheme;
//
// for proof_config in default_proof_config().iter() {
// let msg_bits =
// (param_pke.message_modulus.0 * param_pke.carry_modulus.0).ilog2() as usize;
// println!("Generating CRS... ");
// let crs = CompactPkeCrs::from_shortint_params(
// param_pke,
// LweCiphertextCount(proof_config.crs_size / msg_bits),
// )
// .unwrap();
//
// for bits in proof_config.bits_to_prove.iter() {
// assert_eq!(bits % 64, 0);
// // Packing, so we take the message and carry modulus to compute our block count
// let num_block = 64usize.div_ceil(
// (param_pke.message_modulus.0 * param_pke.carry_modulus.0).ilog2() as usize,
// );
//
// let fhe_uint_count = bits / 64;
//
// for compute_load in [ZkComputeLoad::Proof, ZkComputeLoad::Verify] {
// let zk_load = match compute_load {
// ZkComputeLoad::Proof => "compute_load_proof",
// ZkComputeLoad::Verify => "compute_load_verify",
// };
//
// let bench_id;
//
// match get_bench_type() {
// BenchmarkType::Latency => {
// bench_id = format!(
//
// "{bench_name}::{param_name}_{bits}_bits_packed_{zk_load}_ZK{zk_vers:?}"
// ); bench_group.bench_function(&bench_id, |b| {
// let input_msg = rng.gen::<u64>();
// let messages = vec![input_msg; fhe_uint_count];
//
// b.iter(|| {
// let _ct1 =
// tfhe::integer::ProvenCompactCiphertextList::builder(&pk)
// .extend(messages.iter().copied())
// .build_with_proof_packed(&crs, &metadata,
// compute_load) .unwrap();
// })
// });
// }
// BenchmarkType::Throughput => {
// let elements = zk_throughput_num_elements() * 2; // This value, found
// empirically, ensure saturation of current target // machine
// bench_group.throughput(Throughput::Elements(elements));
//
// bench_id = format!(
//
// "{bench_name}::throughput::{param_name}_{bits}_bits_packed_{zk_load}_ZK{zk_vers:?}"
// ); bench_group.bench_function(&bench_id, |b| {
// let messages = (0..elements)
// .map(|_| {
// let input_msg = rng.gen::<u64>();
// vec![input_msg; fhe_uint_count]
// })
// .collect::<Vec<_>>();
//
// b.iter(|| {
// messages.par_iter().for_each(|msg| {
// tfhe::integer::ProvenCompactCiphertextList::builder(&pk)
// .extend(msg.iter().copied())
// .build_with_proof_packed(&crs, &metadata,
// compute_load) .unwrap();
// })
// })
// });
// }
// }
//
// let shortint_params: PBSParameters = param_fhe.into();
//
// write_to_json::<u64, _>(
// &bench_id,
// shortint_params,
// param_name,
// "pke_zk_proof",
// &OperatorType::Atomic,
// shortint_params.message_modulus().0 as u32,
// vec![shortint_params.message_modulus().0.ilog2(); num_block],
// );
// }
// }
// }
// }
//
// bench_group.finish()
// }
//
/// Benchmarks CPU-side ZK proof generation for compact PKE ciphertext lists
/// (`ProvenCompactCiphertextList::build_with_proof_packed`), registering
/// results in the `zk::pke_zk_proof` Criterion group and recording each
/// benchmark's metadata via `write_to_json`.
fn cpu_pke_zk_proof(c: &mut Criterion) {
let bench_name = "zk::pke_zk_proof";
let mut bench_group = c.benchmark_group(bench_name);
// Proof generation is slow, so use few samples with a long measurement window.
bench_group
.sample_size(15)
.measurement_time(std::time::Duration::from_secs(60));
// Each tuple is (PKE params, casting-key params, compute params).
// NOTE(review): only the 2M64 parameter set is active; the ZKV1 2M128 set
// below is commented out in this WIP state.
for (param_pke, _param_casting, param_fhe) in [
(
V0_11_PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
V0_11_PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
),
// (
// BENCH_PARAM_PKE_TO_SMALL_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128_ZKV1,
// BENCH_PARAM_KEYSWITCH_PKE_TO_SMALL_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128_ZKV1,
// BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
// ),
] {
let param_name = param_fhe.name();
let param_name = param_name.as_str();
// Key material; the casting key itself is unused by the proof benchmark.
let cks = ClientKey::new(param_fhe);
let sks = ServerKey::new_radix_server_key(&cks);
let compact_private_key = CompactPrivateKey::new(param_pke);
let pk = CompactPublicKey::new(&compact_private_key);
// Kept for consistency
let _casting_key =
KeySwitchingKey::new((&compact_private_key, None), (&cks, &sks), _param_casting);
// We have a use case with 320 bits of metadata
let mut metadata = [0u8; (320 / u8::BITS) as usize];
let mut rng = rand::thread_rng();
metadata.fill_with(|| rng.gen());
// ZK scheme version, used only to tag the benchmark id.
let zk_vers = param_pke.zk_scheme;
for proof_config in default_proof_config().iter() {
// Bits of cleartext carried per ciphertext block (message + carry).
let msg_bits =
(param_pke.message_modulus.0 * param_pke.carry_modulus.0).ilog2() as usize;
println!("Generating CRS... ");
// CRS sized in ciphertexts: configured bit budget / bits per block.
let crs = CompactPkeCrs::from_shortint_params(
param_pke,
LweCiphertextCount(proof_config.crs_size / msg_bits),
)
.unwrap();
for bits in proof_config.bits_to_prove.iter() {
// Proved sizes are expressed in whole 64-bit integers.
assert_eq!(bits % 64, 0);
// Packing, so we take the message and carry modulus to compute our block count
let num_block = 64usize.div_ceil(
(param_pke.message_modulus.0 * param_pke.carry_modulus.0).ilog2() as usize,
);
let fhe_uint_count = bits / 64;
// WIP: only the Verify compute load is benchmarked for now;
// the full Proof+Verify sweep is kept commented out below.
// for compute_load in [ZkComputeLoad::Proof, ZkComputeLoad::Verify] {
for compute_load in [ZkComputeLoad::Verify] {
let zk_load = match compute_load {
ZkComputeLoad::Proof => "compute_load_proof",
ZkComputeLoad::Verify => "compute_load_verify",
};
// Assigned exactly once in whichever match arm runs below.
let bench_id;
match BENCH_TYPE.get().unwrap() {
BenchmarkType::Latency => {
// Latency: time one proof over a single message list.
bench_id = format!("{bench_name}::{param_name}_{bits}_bits_packed_{zk_load}_ZK{zk_vers:?}");
bench_group.bench_function(&bench_id, |b| {
let input_msg = rng.gen::<u64>();
let messages = vec![input_msg; fhe_uint_count];
b.iter(|| {
let _ct1 =
tfhe::integer::ProvenCompactCiphertextList::builder(&pk)
.extend(messages.iter().copied())
.build_with_proof_packed(&crs, &metadata, compute_load)
.unwrap();
})
});
}
BenchmarkType::Throughput => {
let elements = zk_throughput_num_elements() * 2; // Empirically chosen to saturate the current target machine.
bench_group.throughput(Throughput::Elements(elements));
bench_id = format!(
"{bench_name}::throughput::{param_name}_{bits}_bits_packed_{zk_load}_ZK{zk_vers:?}"
);
bench_group.bench_function(&bench_id, |b| {
// One independent message list per element, proved in parallel.
let messages = (0..elements)
.map(|_| {
let input_msg = rng.gen::<u64>();
vec![input_msg; fhe_uint_count]
})
.collect::<Vec<_>>();
b.iter(|| {
messages.par_iter().for_each(|msg| {
tfhe::integer::ProvenCompactCiphertextList::builder(&pk)
.extend(msg.iter().copied())
.build_with_proof_packed(&crs, &metadata, compute_load)
.unwrap();
})
})
});
}
}
// Persist the benchmark's parameters alongside its id for reporting.
let shortint_params: PBSParameters = param_fhe.into();
write_to_json::<u64, _>(
&bench_id,
shortint_params,
param_name,
"pke_zk_proof",
&OperatorType::Atomic,
shortint_params.message_modulus().0 as u32,
vec![shortint_params.message_modulus().0.ilog2(); num_block],
);
}
}
}
}
bench_group.finish()
}
// criterion_group!(zk_proof, cpu_pke_zk_proof);
pub fn zk_proof() {
let mut criterion: Criterion<_> = (Criterion::default()).configure_from_args();
cpu_pke_zk_proof(&mut criterion);
}
fn cpu_pke_zk_verify(c: &mut Criterion, results_file: &Path) {
let bench_name = "zk::pke_zk_verify";
@@ -277,7 +282,8 @@ fn cpu_pke_zk_verify(c: &mut Criterion, results_file: &Path) {
vec![],
);
for compute_load in [ZkComputeLoad::Proof, ZkComputeLoad::Verify] {
// for compute_load in [ZkComputeLoad::Proof, ZkComputeLoad::Verify] {
for compute_load in [ZkComputeLoad::Verify] {
let zk_load = match compute_load {
ZkComputeLoad::Proof => "compute_load_proof",
ZkComputeLoad::Verify => "compute_load_verify",
@@ -457,6 +463,7 @@ pub fn zk_verify() {
fn main() {
BENCH_TYPE.get_or_init(|| BenchmarkType::from_env().unwrap());
zk_proof();
zk_verify();
Criterion::default().configure_from_args().final_summary();

View File

@@ -616,8 +616,8 @@ mod cuda_utils {
#[cfg(feature = "integer")]
pub mod cuda_integer_utils {
use tfhe::core_crypto::gpu::{get_number_of_gpus, CudaStreams};
use tfhe::core_crypto::gpu::vec::GpuIndex;
use tfhe::core_crypto::gpu::{get_number_of_gpus, CudaStreams};
use tfhe::integer::gpu::CudaServerKey;
use tfhe::integer::ClientKey;
@@ -645,9 +645,7 @@ mod cuda_utils {
) -> Vec<CudaStreams> {
(0..cuda_num_streams(num_block))
.map(|i| {
CudaStreams::new_single_gpu(GpuIndex(
(i % get_number_of_gpus() as u64) as u32,
))
CudaStreams::new_single_gpu(GpuIndex((i % get_number_of_gpus() as u64) as u32))
})
.cycle()
.take(throughput_elements)