From 524a5df36e75b8703310844613f31a76fee09aa4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Test=C3=A9?= Date: Mon, 17 Nov 2025 15:39:41 +0100 Subject: [PATCH] WIP: put back zk proof bench and run only 256 bits proof --- Cargo.toml | 2 +- tfhe/benches/integer/zk_pke.rs | 274 +++++++++--------- .../noise_squashing/p_fail_2_minus_128/mod.rs | 16 +- 3 files changed, 143 insertions(+), 149 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a2016975c..ab27ba6b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ exclude = [ ] [workspace.dependencies] aligned-vec = { version = "0.6", default-features = false } -bytemuck = "1.14.3" +bytemuck = "<1.24" dyn-stack = { version = "0.11", default-features = false } itertools = "0.14" num-complex = "0.4" diff --git a/tfhe/benches/integer/zk_pke.rs b/tfhe/benches/integer/zk_pke.rs index 253e28c31..bfc3f34c6 100644 --- a/tfhe/benches/integer/zk_pke.rs +++ b/tfhe/benches/integer/zk_pke.rs @@ -16,7 +16,6 @@ use tfhe::keycache::NamedParam; use tfhe::shortint::parameters::*; use tfhe::zk::{CompactPkeCrs, ZkComputeLoad}; - struct ProofConfig { crs_size: usize, bits_to_prove: Vec, @@ -35,12 +34,12 @@ fn default_proof_config() -> Vec { vec![ // ProofConfig::new(64usize, &[64usize]), // ProofConfig::new(640, &[640]), - ProofConfig::new(2048, &[2048, 3 * 64usize]), + // ProofConfig::new(2048, &[2048, 4 * 64usize]), + ProofConfig::new(2048, &[4 * 64usize]), // ProofConfig::new(4096, &[4096]), ] } - fn write_result(file: &mut File, name: &str, value: usize) { let line = format!("{name},{value}\n"); let error_message = format!("cannot write {name} result into file"); @@ -65,135 +64,135 @@ fn zk_throughput_num_elements() -> u64 { } } +fn cpu_pke_zk_proof(c: &mut Criterion) { + let bench_name = "zk::pke_zk_proof"; + let mut bench_group = c.benchmark_group(bench_name); + bench_group + .sample_size(15) + .measurement_time(std::time::Duration::from_secs(60)); + + for (param_pke, _param_casting, param_fhe) in [ + ( + PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + ), + // ( + // BENCH_PARAM_PKE_TO_SMALL_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128_ZKV1, + // BENCH_PARAM_KEYSWITCH_PKE_TO_SMALL_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128_ZKV1, + // BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + // ), + ] { + let param_name = param_fhe.name(); + let param_name = param_name.as_str(); + let cks = ClientKey::new(param_fhe); + let sks = ServerKey::new_radix_server_key(&cks); + let compact_private_key = CompactPrivateKey::new(param_pke); + let pk = CompactPublicKey::new(&compact_private_key); + // Kept for consistency + let _casting_key = + KeySwitchingKey::new((&compact_private_key, None), (&cks, &sks), _param_casting); + + // We have a use case with 320 bits of metadata + let mut metadata = [0u8; (320 / u8::BITS) as usize]; + let mut rng = rand::thread_rng(); + metadata.fill_with(|| rng.gen()); + + let zk_vers = param_pke.zk_scheme; + + for proof_config in default_proof_config().iter() { + let msg_bits = + (param_pke.message_modulus.0 * param_pke.carry_modulus.0).ilog2() as usize; + println!("Generating CRS... "); + let crs = CompactPkeCrs::from_shortint_params( + param_pke, + LweCiphertextCount(proof_config.crs_size / msg_bits), + ) + .unwrap(); + + for bits in proof_config.bits_to_prove.iter() { + assert_eq!(bits % 64, 0); + // Packing, so we take the message and carry modulus to compute our block count + let num_block = 64usize.div_ceil( + (param_pke.message_modulus.0 * param_pke.carry_modulus.0).ilog2() as usize, + ); + + let fhe_uint_count = bits / 64; + + // for compute_load in [ZkComputeLoad::Proof, ZkComputeLoad::Verify] { + for compute_load in [ZkComputeLoad::Verify] { + let zk_load = match compute_load { + ZkComputeLoad::Proof => "compute_load_proof", + ZkComputeLoad::Verify => "compute_load_verify", + }; + + let bench_id; + + match get_bench_type() { + BenchmarkType::Latency => { + bench_id = format!( + "{bench_name}::{param_name}_{bits}_bits_packed_{zk_load}_ZK{zk_vers:?}" + ); + bench_group.bench_function(&bench_id, |b| { + let input_msg = rng.gen::(); + let messages = vec![input_msg; fhe_uint_count]; + + b.iter(|| { + let _ct1 = + tfhe::integer::ProvenCompactCiphertextList::builder(&pk) + .extend(messages.iter().copied()) + .build_with_proof_packed(&crs, &metadata, compute_load) + .unwrap(); + }) + }); + } + BenchmarkType::Throughput => { + let elements = zk_throughput_num_elements() * 2; // This value, found empirically, ensure saturation of current target + // machine + bench_group.throughput(Throughput::Elements(elements)); + + bench_id = format!( + "{bench_name}::throughput::{param_name}_{bits}_bits_packed_{zk_load}_ZK{zk_vers:?}" + ); + bench_group.bench_function(&bench_id, |b| { + let messages = (0..elements) + .map(|_| { + let input_msg = rng.gen::(); + vec![input_msg; fhe_uint_count] + }) + .collect::>(); + + b.iter(|| { + messages.par_iter().for_each(|msg| { + tfhe::integer::ProvenCompactCiphertextList::builder(&pk) + .extend(msg.iter().copied()) + .build_with_proof_packed(&crs, &metadata, compute_load) + .unwrap(); + }) + }) + }); + } + } + + let shortint_params: PBSParameters = param_fhe.into(); + + write_to_json::( + &bench_id, + shortint_params, + param_name, + "pke_zk_proof", + &OperatorType::Atomic, + shortint_params.message_modulus().0 as u32, + vec![shortint_params.message_modulus().0.ilog2(); num_block], + ); + } + } + } + } + + bench_group.finish() +} -// fn cpu_pke_zk_proof(c: &mut Criterion) { -// let bench_name = "zk::pke_zk_proof"; -// let mut bench_group = c.benchmark_group(bench_name); -// bench_group -// .sample_size(15) -// .measurement_time(std::time::Duration::from_secs(60)); -// -// for (param_pke, _param_casting, param_fhe) in [ -// ( -// BENCH_PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, -// BENCH_PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, -// BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, -// ), -// ( -// BENCH_PARAM_PKE_TO_SMALL_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128_ZKV1, -// BENCH_PARAM_KEYSWITCH_PKE_TO_SMALL_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128_ZKV1, -// BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, -// ), -// ] { -// let param_name = param_fhe.name(); -// let param_name = param_name.as_str(); -// let cks = ClientKey::new(param_fhe); -// let sks = ServerKey::new_radix_server_key(&cks); -// let compact_private_key = CompactPrivateKey::new(param_pke); -// let pk = CompactPublicKey::new(&compact_private_key); -// // Kept for consistency -// let _casting_key = -// KeySwitchingKey::new((&compact_private_key, None), (&cks, &sks), _param_casting); -// -// // We have a use case with 320 bits of metadata -// let mut metadata = [0u8; (320 / u8::BITS) as usize]; -// let mut rng = rand::thread_rng(); -// metadata.fill_with(|| rng.gen()); -// -// let zk_vers = param_pke.zk_scheme; -// -// for proof_config in default_proof_config().iter() { -// let msg_bits = -// (param_pke.message_modulus.0 * param_pke.carry_modulus.0).ilog2() as usize; -// println!("Generating CRS... "); -// let crs = CompactPkeCrs::from_shortint_params( -// param_pke, -// LweCiphertextCount(proof_config.crs_size / msg_bits), -// ) -// .unwrap(); -// -// for bits in proof_config.bits_to_prove.iter() { -// assert_eq!(bits % 64, 0); -// // Packing, so we take the message and carry modulus to compute our block count -// let num_block = 64usize.div_ceil( -// (param_pke.message_modulus.0 * param_pke.carry_modulus.0).ilog2() as usize, -// ); -// -// let fhe_uint_count = bits / 64; -// -// for compute_load in [ZkComputeLoad::Proof, ZkComputeLoad::Verify] { -// let zk_load = match compute_load { -// ZkComputeLoad::Proof => "compute_load_proof", -// ZkComputeLoad::Verify => "compute_load_verify", -// }; -// -// let bench_id; -// -// match get_bench_type() { -// BenchmarkType::Latency => { -// bench_id = format!( -// "{bench_name}::{param_name}_{bits}_bits_packed_{zk_load}_ZK{zk_vers:?}" -// ); -// bench_group.bench_function(&bench_id, |b| { -// let input_msg = rng.gen::(); -// let messages = vec![input_msg; fhe_uint_count]; -// -// b.iter(|| { -// let _ct1 = -// tfhe::integer::ProvenCompactCiphertextList::builder(&pk) -// .extend(messages.iter().copied()) -// .build_with_proof_packed(&crs, &metadata, compute_load) -// .unwrap(); -// }) -// }); -// } -// BenchmarkType::Throughput => { -// let elements = zk_throughput_num_elements() * 2; // This value, found empirically, ensure saturation of current target -// // machine -// bench_group.throughput(Throughput::Elements(elements)); -// -// bench_id = format!( -// "{bench_name}::throughput::{param_name}_{bits}_bits_packed_{zk_load}_ZK{zk_vers:?}" -// ); -// bench_group.bench_function(&bench_id, |b| { -// let messages = (0..elements) -// .map(|_| { -// let input_msg = rng.gen::(); -// vec![input_msg; fhe_uint_count] -// }) -// .collect::>(); -// -// b.iter(|| { -// messages.par_iter().for_each(|msg| { -// tfhe::integer::ProvenCompactCiphertextList::builder(&pk) -// .extend(msg.iter().copied()) -// .build_with_proof_packed(&crs, &metadata, compute_load) -// .unwrap(); -// }) -// }) -// }); -// } -// } -// -// let shortint_params: PBSParameters = param_fhe.into(); -// -// write_to_json::( -// &bench_id, -// shortint_params, -// param_name, -// "pke_zk_proof", -// &OperatorType::Atomic, -// shortint_params.message_modulus().0 as u32, -// vec![shortint_params.message_modulus().0.ilog2(); num_block], -// ); -// } -// } -// } -// } -// -// bench_group.finish() -// } -// // criterion_group!(zk_proof, cpu_pke_zk_proof); fn cpu_pke_zk_verify(c: &mut Criterion, results_file: &Path) { @@ -245,7 +244,7 @@ fn cpu_pke_zk_verify(c: &mut Criterion, results_file: &Path) { param_pke, LweCiphertextCount(proof_config.crs_size / msg_bits), ) - .unwrap(); + .unwrap(); for bits in proof_config.bits_to_prove.iter() { assert_eq!(bits % 64, 0); @@ -276,7 +275,8 @@ fn cpu_pke_zk_verify(c: &mut Criterion, results_file: &Path) { vec![], ); - for compute_load in [ZkComputeLoad::Proof, ZkComputeLoad::Verify] { + // for compute_load in [ZkComputeLoad::Proof, ZkComputeLoad::Verify] { + for compute_load in [ZkComputeLoad::Verify] { let zk_load = match compute_load { ZkComputeLoad::Proof => "compute_load_proof", ZkComputeLoad::Verify => "compute_load_verify", @@ -447,10 +447,6 @@ fn cpu_pke_zk_verify(c: &mut Criterion, results_file: &Path) { bench_group.finish() } - - - - #[cfg(all(feature = "gpu", feature = "zk-pok"))] mod cuda { use super::*; @@ -506,7 +502,7 @@ mod cuda { param_pke, LweCiphertextCount(proof_config.crs_size / msg_bits), ) - .unwrap(); + .unwrap(); use rand::Rng; let mut rng = rand::thread_rng(); @@ -811,7 +807,7 @@ pub fn zk_verify_and_proof() { let results_file = Path::new("pke_zk_crs_sizes.csv"); let mut criterion: Criterion<_> = (Criterion::default()).configure_from_args(); cpu_pke_zk_verify(&mut criterion, results_file); - // cpu_pke_zk_proof(&mut criterion); + cpu_pke_zk_proof(&mut criterion); } #[cfg(all(feature = "gpu", feature = "zk-pok"))] diff --git a/tfhe/src/shortint/parameters/v1_1/noise_squashing/p_fail_2_minus_128/mod.rs b/tfhe/src/shortint/parameters/v1_1/noise_squashing/p_fail_2_minus_128/mod.rs index 0ccceedf1..443e91b4a 100644 --- a/tfhe/src/shortint/parameters/v1_1/noise_squashing/p_fail_2_minus_128/mod.rs +++ b/tfhe/src/shortint/parameters/v1_1/noise_squashing/p_fail_2_minus_128/mod.rs @@ -24,20 +24,18 @@ pub const V1_1_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: }; pub const V1_3_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: -NoiseSquashingParameters = NoiseSquashingParameters { + NoiseSquashingParameters = NoiseSquashingParameters { glwe_dimension: GlweDimension(2), polynomial_size: PolynomialSize(2048), glwe_noise_distribution: DynamicDistribution::new_t_uniform(30), decomp_base_log: DecompositionBaseLog(24), decomp_level_count: DecompositionLevelCount(3), - modulus_switch_noise_reduction_params: Some( - ModulusSwitchNoiseReductionParams { - modulus_switch_zeros_count: LweCiphertextCount(1449), - ms_bound: NoiseEstimationMeasureBound(288230376151711744f64), - ms_r_sigma_factor: RSigmaFactor(13.179852282053789f64), - ms_input_variance: Variance(2.63039184094559E-7f64), - }, - ), + modulus_switch_noise_reduction_params: Some(ModulusSwitchNoiseReductionParams { + modulus_switch_zeros_count: LweCiphertextCount(1449), + ms_bound: NoiseEstimationMeasureBound(288230376151711744f64), + ms_r_sigma_factor: RSigmaFactor(13.179852282053789f64), + ms_input_variance: Variance(2.63039184094559E-7f64), + }), message_modulus: MessageModulus(4), carry_modulus: CarryModulus(4), ciphertext_modulus: CoreCiphertextModulus::::new_native(),