diff --git a/Cargo.toml b/Cargo.toml index d90e927..fe8a848 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,3 +57,7 @@ ark-serialize = { git = "https://github.com/arkworks-rs/algebra" } hex = "0.4" num-bigint = { version = "0.4.3", features = ["rand"] } criterion = "0.3" + +[[bench]] +name = "halo2_proof_generation" +harness = false \ No newline at end of file diff --git a/benches/halo2_proof_generation.rs b/benches/halo2_proof_generation.rs index 104d00e..e73cc46 100644 --- a/benches/halo2_proof_generation.rs +++ b/benches/halo2_proof_generation.rs @@ -1,90 +1,120 @@ use authdecode::halo2_backend::onetimesetup::OneTimeSetup; -use authdecode::halo2_backend::prover::Prover; -use authdecode::halo2_backend::verifier::Verifier; +use authdecode::halo2_backend::prover::{Prover, PK}; +use authdecode::halo2_backend::verifier::{Verifier, VK}; use authdecode::halo2_backend::Curve; -use authdecode::prover::AuthDecodeProver; -use authdecode::verifier::AuthDecodeVerifier; +use authdecode::prover::{AuthDecodeProver, ProofCreation}; +use authdecode::verifier::{AuthDecodeVerifier, VerifyMany}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use rand::thread_rng; use rand::Rng; +use std::env; -fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("proof_generation", move |bench| { - // The Prover should have generated the proving key (before the authdecode - // protocol starts) like this: - let proving_key = OneTimeSetup::proving_key(); +pub fn criterion_benchmark(c: &mut Criterion) { + // benchmarking single threaded halo2 + env::set_var("RAYON_NUM_THREADS", "1"); - // The Verifier should have generated the verifying key (before the authdecode - // protocol starts) like this: - let verification_key = OneTimeSetup::verification_key(); + let proving_key = OneTimeSetup::proving_key(); + let verification_key = OneTimeSetup::verification_key(); - let prover = Box::new(Prover::new(proving_key)); - let verifier = Box::new(Verifier::new(verification_key, Curve::Pallas)); - let mut rng = thread_rng(); - - // generate random plaintext of random size up to 400 bytes - let plaintext: Vec = core::iter::repeat_with(|| rng.gen::()) - .take(thread_rng().gen_range(0..400)) - .collect(); - - // Normally, the Prover is expected to obtain her binary labels by - // evaluating the garbled circuit. - // To keep this test simple, we don't evaluate the gc, but we generate - // all labels of the Verifier and give the Prover her active labels. - let bit_size = plaintext.len() * 8; - let mut all_binary_labels: Vec<[u128; 2]> = Vec::with_capacity(bit_size); - let mut delta: u128 = rng.gen(); - // set the last bit - delta |= 1; - for _ in 0..bit_size { - let label_zero: u128 = rng.gen(); - all_binary_labels.push([label_zero, label_zero ^ delta]); - } - let prover_labels = choose(&all_binary_labels, &u8vec_to_boolvec(&plaintext)); - - let verifier = AuthDecodeVerifier::new(all_binary_labels.clone(), verifier); - - let verifier = verifier.setup().unwrap(); - - let prover = AuthDecodeProver::new(plaintext, prover); - - // Perform setup - let prover = prover.setup().unwrap(); - - // Commitment to the plaintext is sent to the Notary - let (plaintext_hash, prover) = prover.plaintext_commitment().unwrap(); - - // Notary sends back encrypted arithm. labels. - let (ciphertexts, verifier) = verifier.receive_plaintext_hashes(plaintext_hash).unwrap(); - - // Hash commitment to the label_sum is sent to the Notary - let (label_sum_hashes, prover) = prover - .label_sum_commitment(ciphertexts, &prover_labels) - .unwrap(); - - // Notary sends the arithmetic label seed - let (seed, verifier) = verifier.receive_label_sum_hashes(label_sum_hashes).unwrap(); - - // At this point the following happens in the `committed GC` protocol: - // - the Notary reveals the GC seed - // - the User checks that the GC was created from that seed - // - the User checks that her active output labels correspond to the - // output labels derived from the seed - // - we are called with the result of the check and (if successful) - // with all the output labels - - let prover = prover - .binary_labels_authenticated(true, Some(all_binary_labels)) - .unwrap(); - - // Prover checks the integrity of the arithmetic labels and generates zero_sums and deltas - let prover = prover.authenticate_arithmetic_labels(seed).unwrap(); - - bench.iter(|| { - //Prover generates the proof - black_box(prover.create_zk_proofs()); - }); + c.bench_function("halo2_proof_generation_single_threaded", |b| { + b.iter(|| { + // Since we can't Clone provers, we generate a new prover for each + // iteration. This should not add more than 1-2% runtime to the bench + let (prover, _verifier) = create_prover(proving_key.clone(), verification_key.clone()); + black_box(prover.create_zk_proofs().unwrap()); + }) }); + + // We cannot bench proof verification without running the proof generation. + // To get the actual verification time, subtract from "generation+verification" + // time the "generation only" time from the above bench. + + c.bench_function( + "halo2_proof_generation_and_verification_single_threaded", + |b| { + b.iter(|| { + // Since we can't Clone prover, verifier, we generate a new prover and a new verifier + // for each iteration. This should not add more than 1-2% runtime to the bench + let (prover, verifier) = + create_prover(proving_key.clone(), verification_key.clone()); + let (proofs, _salts) = prover.create_zk_proofs().unwrap(); + black_box(verifier.verify_many(proofs.clone()).unwrap()); + }) + }, + ); +} + +// Runs the whole protocol and returns the prover in a state ready to create +// proofs and a verifier ready to verify proofs. +fn create_prover( + proving_key: PK, + verification_key: VK, +) -> ( + AuthDecodeProver, + AuthDecodeVerifier, +) { + let prover = Box::new(Prover::new(proving_key)); + let verifier = Box::new(Verifier::new(verification_key, Curve::Pallas)); + let mut rng = thread_rng(); + + // generate random plaintext of random size up to 400 bytes + let plaintext: Vec = core::iter::repeat_with(|| rng.gen::()) + .take(thread_rng().gen_range(1..400)) + .collect(); + + // Normally, the Prover is expected to obtain her binary labels by + // evaluating the garbled circuit. + // To keep this test simple, we don't evaluate the gc, but we generate + // all labels of the Verifier and give the Prover her active labels. + let bit_size = plaintext.len() * 8; + let mut all_binary_labels: Vec<[u128; 2]> = Vec::with_capacity(bit_size); + let mut delta: u128 = rng.gen(); + // set the last bit + delta |= 1; + for _ in 0..bit_size { + let label_zero: u128 = rng.gen(); + all_binary_labels.push([label_zero, label_zero ^ delta]); + } + let prover_labels = choose(&all_binary_labels, &u8vec_to_boolvec(&plaintext)); + + let verifier = AuthDecodeVerifier::new(all_binary_labels.clone(), verifier); + + let verifier = verifier.setup().unwrap(); + + let prover = AuthDecodeProver::new(plaintext, prover); + + // Perform setup + let prover = prover.setup().unwrap(); + + // Commitment to the plaintext is sent to the Notary + let (plaintext_hash, prover) = prover.plaintext_commitment().unwrap(); + + // Notary sends back encrypted arithm. labels. + let (ciphertexts, verifier) = verifier.receive_plaintext_hashes(plaintext_hash).unwrap(); + + // Hash commitment to the label_sum is sent to the Notary + let (label_sum_hashes, prover) = prover + .label_sum_commitment(ciphertexts, &prover_labels) + .unwrap(); + + // Notary sends the arithmetic label seed + let (seed, verifier) = verifier.receive_label_sum_hashes(label_sum_hashes).unwrap(); + + // At this point the following happens in the `committed GC` protocol: + // - the Notary reveals the GC seed + // - the User checks that the GC was created from that seed + // - the User checks that her active output labels correspond to the + // output labels derived from the seed + // - we are called with the result of the check and (if successful) + // with all the output labels + + let prover = prover + .binary_labels_authenticated(true, Some(all_binary_labels)) + .unwrap(); + + // Prover checks the integrity of the arithmetic labels and generates zero_sums and deltas + let prover = prover.authenticate_arithmetic_labels(seed).unwrap(); + (prover, verifier) } /// Unzips a slice of pairs, returning items corresponding to choice diff --git a/benches/testbench.rs b/benches/testbench.rs new file mode 100644 index 0000000..afbb499 --- /dev/null +++ b/benches/testbench.rs @@ -0,0 +1,9 @@ +use authdecode::fibonacci; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +pub fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("test2", |b| b.iter(|| fibonacci(black_box(20)))); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/src/halo2_backend/circuit.rs b/src/halo2_backend/circuit.rs index b3fee3b..6b9c6f3 100644 --- a/src/halo2_backend/circuit.rs +++ b/src/halo2_backend/circuit.rs @@ -492,15 +492,15 @@ impl Circuit for AuthDecodeCircuit { &cfg, offset, )?; - offset += 1; + + // uncomment this if in the future we may want to do more + // computations in the scratch space + // offset += 1; // replace the last field element with the one with salt plaintext[pt_len - 1] = last_with_salt; - println!("{:?} final `scratch_space` offset", offset); - - //Ok((label_sum, plaintext)) - + //println!("{:?} final `scratch_space` offset", offset); Ok((label_sum_salted, plaintext)) }, )?; diff --git a/src/halo2_backend/mod.rs b/src/halo2_backend/mod.rs index 782a7e6..6e77a3c 100644 --- a/src/halo2_backend/mod.rs +++ b/src/halo2_backend/mod.rs @@ -47,14 +47,42 @@ mod tests { } #[test] + // As of Oct 2022 there appears to be a bug in halo2 which causes the prove + // times with MockProver be as long as with a real prover. Marking this test + // as expensive. + #[ignore = "expensive"] /// Tests that the protocol runs successfully fn halo2_e2e_test_success() { - halo2_e2e_test(false); + // This test causes the "thread ... has overflowed its stack" error + // The only way to increase the stack size is to spawn a new thread with + // the test. + // See https://github.com/rust-lang/rustfmt/issues/3473 + use std::thread; + thread::Builder::new() + .stack_size(8388608) + .spawn(|| halo2_e2e_test(false)) + .expect("Failed to create a test thread") + .join() + .expect("Failed to join a test thread"); } #[test] + // As of Oct 2022 there appears to be a bug in halo2 which causes the prove + // times with MockProver be as long as with a real prover. Marking this test + // as expensive. + #[ignore = "expensive"] /// Tests that a corrupted proof causes verification to fail fn halo2_e2e_test_failure() { - halo2_e2e_test(true); + // This test causes the "thread ... has overflowed its stack" error + // The only way to increase the stack size is to spawn a new thread with + // the test. + // See https://github.com/rust-lang/rustfmt/issues/3473 + use std::thread; + thread::Builder::new() + .stack_size(8388608) + .spawn(|| halo2_e2e_test(true)) + .expect("Failed to create a test thread") + .join() + .expect("Failed to join a test thread"); } } diff --git a/src/halo2_backend/prover.rs b/src/halo2_backend/prover.rs index 040b644..f7690cd 100644 --- a/src/halo2_backend/prover.rs +++ b/src/halo2_backend/prover.rs @@ -7,7 +7,6 @@ use halo2_proofs::plonk; use halo2_proofs::plonk::ProvingKey; use halo2_proofs::poly::commitment::Params; use halo2_proofs::transcript::{Blake2bWrite, Challenge255}; -use instant::Instant; use num::BigUint; use pasta_curves::pallas::Base as F; use pasta_curves::EqAffine; @@ -60,8 +59,6 @@ impl Prove for Prover { ]; all_inputs.push(tmp); - let now = Instant::now(); - // prepare the proving system and generate the proof: let circuit = @@ -74,6 +71,8 @@ impl Prove for Prover { let mut rng = thread_rng(); + // let now = Instant::now(); + let res = plonk::create_proof( params, pk, @@ -86,9 +85,9 @@ impl Prove for Prover { return Err(ProverError::ProvingBackendError); } - println!("Proof created [{:?}]", now.elapsed()); + // println!("Proof created [{:?}]", now.elapsed()); let proof = transcript.finalize(); - println!("Proof size [{} kB]", proof.len() as f64 / 1024.0); + // println!("Proof size [{} kB]", proof.len() as f64 / 1024.0); Ok(proof) } @@ -154,8 +153,9 @@ fn hash_internal(inputs: &Vec) -> Result { #[cfg(test)] mod tests { use super::*; - use crate::halo2_backend::circuit::{CELLS_PER_ROW, K, USEFUL_ROWS}; + use crate::halo2_backend::circuit::{CELLS_PER_ROW, K}; use crate::halo2_backend::prover::hash_internal; + use crate::halo2_backend::utils::bigint_to_256bits; use crate::halo2_backend::Curve; use crate::prover::{ProofInput, Prove, ProverError}; use crate::tests::run_until_proofs_are_generated; @@ -163,7 +163,6 @@ mod tests { use crate::Proof; use halo2_proofs::dev::MockProver; use num::BigUint; - use std::panic::catch_unwind; /// TestHalo2Prover is a test prover. It is the same as [Prover] except: /// - it doesn't require a proving key @@ -207,41 +206,52 @@ mod tests { ); // Test with the correct inputs. + // Expect successful verification. - println!("start mockprover"); let prover = MockProver::run(K, &circuit, good_inputs.clone()).unwrap(); assert!(prover.verify().is_ok()); - println!("end mockprover"); - // Corrupt at least one delta which corresponds to plaintext bit 1. - // Since the plaintext was chosen randomly, we corrupt only the last - // deltas on each row - one of those deltas will correspond to a plaintext - // bit 1 with high probability. - // Expect halo2 to panic. - let mut bad_input1 = good_inputs.clone(); - for i in 0..USEFUL_ROWS { - bad_input1[CELLS_PER_ROW - 1][i] = F::from(123); + // Find one delta which corresponds to plaintext bit 1 and corrupt + // the delta: + + // Find the first bit 1 in plaintext + let bits = bigint_to_256bits(input.plaintext[0].clone()); + let mut offset: i32 = -1; + for (i, b) in bits.iter().enumerate() { + if *b == true { + offset = i as i32; + break; + } } + // first field element of the plaintext is not expected to have all + // bits set to zero. + assert!(offset != -1); + let offset = offset as usize; + + // Find the position of the corresponding delta. The position is + // row/column in the halo2 table + let col = offset % CELLS_PER_ROW; + let row = offset / CELLS_PER_ROW; + + // Corrupt the delta + let mut bad_input1 = good_inputs.clone(); + bad_input1[col][row] = F::from(123); - println!("start mockprover2"); let prover = MockProver::run(K, &circuit, bad_input1.clone()).unwrap(); assert!(prover.verify().is_err()); - println!("end mockprover2"); // One-by-one corrupt the plaintext hash, the label sum hash, the zero sum. - // Expect halo2 to panic. + // Expect verification error. for i in 0..3 { let mut bad_public_input = good_inputs.clone(); bad_public_input[CELLS_PER_ROW][i] = F::from(123); - println!("start mockprover3"); let prover = MockProver::run(K, &circuit, bad_public_input.clone()).unwrap(); assert!(prover.verify().is_err()); - println!("end mockprover3"); } // Corrupt only the plaintext. - // Expect halo2 to panic. + // Expect verification error. let mut bad_plaintext = good_plaintext.clone(); bad_plaintext[0] = F::from(123); @@ -250,13 +260,11 @@ mod tests { bigint_to_f(&input.salt), deltas_as_rows.into(), ); - println!("start mockprover4"); let prover = MockProver::run(K, &circuit, good_inputs.clone()).unwrap(); assert!(prover.verify().is_err()); - println!("end mockprover4"); // Corrupt only the salt. - // Expect halo2 to panic. + // Expect verification error. let bad_salt = BigUint::from(123u8); let circuit = AuthDecodeCircuit::new( @@ -264,10 +272,8 @@ mod tests { bigint_to_f(&bad_salt), deltas_as_rows.into(), ); - println!("start mockprover5"); let prover = MockProver::run(K, &circuit, good_inputs.clone()).unwrap(); assert!(prover.verify().is_err()); - println!("end mockprover5"); Ok(Default::default()) } @@ -325,7 +331,6 @@ mod tests { match self.curve { Curve::Pallas => 255, Curve::BN254 => 254, - _ => panic!("a new curve was added. Add its field size here."), } } @@ -339,11 +344,27 @@ mod tests { } #[test] + // As of Oct 2022 there appears to be a bug in halo2 which causes the prove + // times with MockProver be as long as with a real prover. Marking this test + // as expensive. + #[ignore = "expensive"] /// Tests the circuit with the correct inputs as well as wrong inputs. The logic is /// in [TestHalo2Prover]'s prove() fn test_circuit() { - let prover = Box::new(TestHalo2Prover::new()); - let verifier = Box::new(TestHalo2Verifier::new(Curve::Pallas)); - let _res = run_until_proofs_are_generated(prover, verifier); + // This test causes the "thread ... has overflowed its stack" error + // The only way to increase the stack size is to spawn a new thread with + // the test. + // See https://github.com/rust-lang/rustfmt/issues/3473 + use std::thread; + thread::Builder::new() + .stack_size(8388608) + .spawn(|| { + let prover = Box::new(TestHalo2Prover::new()); + let verifier = Box::new(TestHalo2Verifier::new(Curve::Pallas)); + let _ = run_until_proofs_are_generated(prover, verifier); + }) + .expect("Failed to create a test thread") + .join() + .expect("Failed to join a test thread"); } } diff --git a/src/halo2_backend/utils.rs b/src/halo2_backend/utils.rs index 6c3d43d..d8c3e31 100644 --- a/src/halo2_backend/utils.rs +++ b/src/halo2_backend/utils.rs @@ -2,7 +2,7 @@ use super::circuit::{CELLS_PER_ROW, USEFUL_ROWS}; use crate::utils::{boolvec_to_u8vec, u8vec_to_boolvec}; use crate::Delta; use halo2_proofs::arithmetic::FieldExt; -use num::{BigUint, FromPrimitive}; +use num::BigUint; use pasta_curves::Fp as F; /// Decomposes a `BigUint` into bits and returns the bits in MSB-first bit order, @@ -86,7 +86,7 @@ fn test_f_to_bigint() { let b = rng.gen::(); let res = f_to_bigint(&(F::from_u128(a) + F::from_u128(b))); - let expected: BigUint = BigUint::from_u128(a).unwrap() + BigUint::from_u128(b).unwrap(); + let expected: BigUint = BigUint::from(a) + BigUint::from(b); assert_eq!(res, expected); } diff --git a/src/halo2_backend/verifier.rs b/src/halo2_backend/verifier.rs index 0678595..9e75e20 100644 --- a/src/halo2_backend/verifier.rs +++ b/src/halo2_backend/verifier.rs @@ -7,7 +7,6 @@ use halo2_proofs::plonk::VerifyingKey; use halo2_proofs::poly::commitment::Params; use halo2_proofs::transcript::Blake2bRead; use halo2_proofs::transcript::Challenge255; -use instant::Instant; use pasta_curves::pallas::Base as F; use pasta_curves::EqAffine; @@ -55,7 +54,7 @@ impl Verify for Verifier { ]; all_inputs.push(tmp); - let now = Instant::now(); + // let now = Instant::now(); // perform the actual verification let res = plonk::verify_proof( params, @@ -64,7 +63,7 @@ impl Verify for Verifier { &[all_inputs.as_slice()], &mut transcript, ); - println!("Proof verified [{:?}]", now.elapsed()); + // println!("Proof verified [{:?}]", now.elapsed()); if res.is_err() { return Err(VerifierError::VerificationFailed); } else { @@ -76,7 +75,6 @@ impl Verify for Verifier { match self.curve { Curve::Pallas => 255, Curve::BN254 => 254, - _ => panic!("a new curve was added. Add its field size here."), } } diff --git a/src/lib.rs b/src/lib.rs index 9410221..82be38d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -133,7 +133,7 @@ mod tests { // generate random plaintext of random size up to 1000 bytes let plaintext: Vec = core::iter::repeat_with(|| rng.gen::()) - .take(thread_rng().gen_range(0..1000)) + .take(thread_rng().gen_range(1..1000)) .collect(); // Normally, the Prover is expected to obtain her binary labels by diff --git a/src/prover.rs b/src/prover.rs index de4d024..3208faf 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -182,6 +182,7 @@ pub trait Prove { fn hash(&self, inputs: &Vec) -> Result; } +/// Implementation of the prover in the AuthDecode protocol. pub struct AuthDecodeProver where S: State, @@ -831,13 +832,13 @@ mod tests { #[test] /// Sets too few binary labels and triggers [ProverError::IncorrectBinaryLabelSize] fn test_error_incorrect_binary_label_size() { - let pt_len = 1000; - let ciphertexts = vec![[[0u8; 16], [0u8; 16]]; pt_len * 8]; + let plaintext_size = 1000; + let ciphertexts = vec![[[0u8; 16], [0u8; 16]]; plaintext_size]; let labels = vec![0u128]; let lsp = AuthDecodeProver { state: LabelSumCommitment { - plaintext_size: 0, + plaintext_size, ..Default::default() }, prover: Box::new(CorrectTestProver {}), @@ -846,7 +847,7 @@ mod tests { assert_eq!( res.err().unwrap(), - ProverError::IncorrectBinaryLabelSize(pt_len, 1) + ProverError::IncorrectBinaryLabelSize(plaintext_size, 1) ); } diff --git a/src/utils.rs b/src/utils.rs index 93b3ad7..2ee0c4f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,6 +1,5 @@ use crate::{Delta, ZeroSum}; use aes::{Aes128, NewBlockCipher}; -use ark_ff::BigInt; use cipher::{consts::U16, generic_array::GenericArray, BlockEncrypt}; use num::BigUint; use sha2::{Digest, Sha256}; diff --git a/src/verifier.rs b/src/verifier.rs index 00666c3..bf0437b 100644 --- a/src/verifier.rs +++ b/src/verifier.rs @@ -62,11 +62,6 @@ pub struct VerifyMany { } impl State for VerifyMany {} -pub struct VerificationSuccessfull { - plaintext_hashes: Vec, -} -impl State for VerificationSuccessfull {} - pub trait Verify { /// Verifies the zk proof against public `input`s. Returns `true` on success, /// `false` otherwise. @@ -85,6 +80,8 @@ pub trait Verify { /// of the last field element of each chunk. fn chunk_size(&self) -> usize; } + +/// Implementation of the verifier in the AuthDecode protocol. pub struct AuthDecodeVerifier where S: State,