cleanups before PR upstream

This commit is contained in:
themighty1
2022-10-13 16:16:56 +03:00
parent 567537cc37
commit 6c1b1e3e63
12 changed files with 221 additions and 134 deletions

View File

@@ -57,3 +57,7 @@ ark-serialize = { git = "https://github.com/arkworks-rs/algebra" }
hex = "0.4"
num-bigint = { version = "0.4.3", features = ["rand"] }
criterion = "0.3"
[[bench]]
name = "halo2_proof_generation"
harness = false

View File

@@ -1,90 +1,120 @@
use authdecode::halo2_backend::onetimesetup::OneTimeSetup;
use authdecode::halo2_backend::prover::Prover;
use authdecode::halo2_backend::verifier::Verifier;
use authdecode::halo2_backend::prover::{Prover, PK};
use authdecode::halo2_backend::verifier::{Verifier, VK};
use authdecode::halo2_backend::Curve;
use authdecode::prover::AuthDecodeProver;
use authdecode::verifier::AuthDecodeVerifier;
use authdecode::prover::{AuthDecodeProver, ProofCreation};
use authdecode::verifier::{AuthDecodeVerifier, VerifyMany};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rand::thread_rng;
use rand::Rng;
use std::env;
fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("proof_generation", move |bench| {
// The Prover should have generated the proving key (before the authdecode
// protocol starts) like this:
let proving_key = OneTimeSetup::proving_key();
pub fn criterion_benchmark(c: &mut Criterion) {
// benchmarking single threaded halo2
env::set_var("RAYON_NUM_THREADS", "1");
// The Verifier should have generated the verifying key (before the authdecode
// protocol starts) like this:
let verification_key = OneTimeSetup::verification_key();
let proving_key = OneTimeSetup::proving_key();
let verification_key = OneTimeSetup::verification_key();
let prover = Box::new(Prover::new(proving_key));
let verifier = Box::new(Verifier::new(verification_key, Curve::Pallas));
let mut rng = thread_rng();
// generate random plaintext of random size up to 400 bytes
let plaintext: Vec<u8> = core::iter::repeat_with(|| rng.gen::<u8>())
.take(thread_rng().gen_range(0..400))
.collect();
// Normally, the Prover is expected to obtain her binary labels by
// evaluating the garbled circuit.
// To keep this test simple, we don't evaluate the gc, but we generate
// all labels of the Verifier and give the Prover her active labels.
let bit_size = plaintext.len() * 8;
let mut all_binary_labels: Vec<[u128; 2]> = Vec::with_capacity(bit_size);
let mut delta: u128 = rng.gen();
// set the last bit
delta |= 1;
for _ in 0..bit_size {
let label_zero: u128 = rng.gen();
all_binary_labels.push([label_zero, label_zero ^ delta]);
}
let prover_labels = choose(&all_binary_labels, &u8vec_to_boolvec(&plaintext));
let verifier = AuthDecodeVerifier::new(all_binary_labels.clone(), verifier);
let verifier = verifier.setup().unwrap();
let prover = AuthDecodeProver::new(plaintext, prover);
// Perform setup
let prover = prover.setup().unwrap();
// Commitment to the plaintext is sent to the Notary
let (plaintext_hash, prover) = prover.plaintext_commitment().unwrap();
// Notary sends back encrypted arithm. labels.
let (ciphertexts, verifier) = verifier.receive_plaintext_hashes(plaintext_hash).unwrap();
// Hash commitment to the label_sum is sent to the Notary
let (label_sum_hashes, prover) = prover
.label_sum_commitment(ciphertexts, &prover_labels)
.unwrap();
// Notary sends the arithmetic label seed
let (seed, verifier) = verifier.receive_label_sum_hashes(label_sum_hashes).unwrap();
// At this point the following happens in the `committed GC` protocol:
// - the Notary reveals the GC seed
// - the User checks that the GC was created from that seed
// - the User checks that her active output labels correspond to the
// output labels derived from the seed
// - we are called with the result of the check and (if successful)
// with all the output labels
let prover = prover
.binary_labels_authenticated(true, Some(all_binary_labels))
.unwrap();
// Prover checks the integrity of the arithmetic labels and generates zero_sums and deltas
let prover = prover.authenticate_arithmetic_labels(seed).unwrap();
bench.iter(|| {
//Prover generates the proof
black_box(prover.create_zk_proofs());
});
c.bench_function("halo2_proof_generation_single_threaded", |b| {
b.iter(|| {
// Since we can't Clone provers, we generate a new prover for each
// iteration. This should not add more than 1-2% runtime to the bench
let (prover, _verifier) = create_prover(proving_key.clone(), verification_key.clone());
black_box(prover.create_zk_proofs().unwrap());
})
});
// We cannot bench proof verification without running the proof generation.
// To get the actual verification time, subtract from "generation+verification"
// time the "generation only" time from the above bench.
c.bench_function(
"halo2_proof_generation_and_verification_single_threaded",
|b| {
b.iter(|| {
// Since we can't Clone prover, verifier, we generate a new prover and a new verifier
// for each iteration. This should not add more than 1-2% runtime to the bench
let (prover, verifier) =
create_prover(proving_key.clone(), verification_key.clone());
let (proofs, _salts) = prover.create_zk_proofs().unwrap();
black_box(verifier.verify_many(proofs.clone()).unwrap());
})
},
);
}
// Runs the whole protocol and returns the prover in a state ready to create
// proofs and a verifier ready to verify proofs.
fn create_prover(
proving_key: PK,
verification_key: VK,
) -> (
AuthDecodeProver<ProofCreation>,
AuthDecodeVerifier<VerifyMany>,
) {
let prover = Box::new(Prover::new(proving_key));
let verifier = Box::new(Verifier::new(verification_key, Curve::Pallas));
let mut rng = thread_rng();
// generate random plaintext of random size up to 400 bytes
let plaintext: Vec<u8> = core::iter::repeat_with(|| rng.gen::<u8>())
.take(thread_rng().gen_range(1..400))
.collect();
// Normally, the Prover is expected to obtain her binary labels by
// evaluating the garbled circuit.
// To keep this test simple, we don't evaluate the gc, but we generate
// all labels of the Verifier and give the Prover her active labels.
let bit_size = plaintext.len() * 8;
let mut all_binary_labels: Vec<[u128; 2]> = Vec::with_capacity(bit_size);
let mut delta: u128 = rng.gen();
// set the last bit
delta |= 1;
for _ in 0..bit_size {
let label_zero: u128 = rng.gen();
all_binary_labels.push([label_zero, label_zero ^ delta]);
}
let prover_labels = choose(&all_binary_labels, &u8vec_to_boolvec(&plaintext));
let verifier = AuthDecodeVerifier::new(all_binary_labels.clone(), verifier);
let verifier = verifier.setup().unwrap();
let prover = AuthDecodeProver::new(plaintext, prover);
// Perform setup
let prover = prover.setup().unwrap();
// Commitment to the plaintext is sent to the Notary
let (plaintext_hash, prover) = prover.plaintext_commitment().unwrap();
// Notary sends back encrypted arithm. labels.
let (ciphertexts, verifier) = verifier.receive_plaintext_hashes(plaintext_hash).unwrap();
// Hash commitment to the label_sum is sent to the Notary
let (label_sum_hashes, prover) = prover
.label_sum_commitment(ciphertexts, &prover_labels)
.unwrap();
// Notary sends the arithmetic label seed
let (seed, verifier) = verifier.receive_label_sum_hashes(label_sum_hashes).unwrap();
// At this point the following happens in the `committed GC` protocol:
// - the Notary reveals the GC seed
// - the User checks that the GC was created from that seed
// - the User checks that her active output labels correspond to the
// output labels derived from the seed
// - we are called with the result of the check and (if successful)
// with all the output labels
let prover = prover
.binary_labels_authenticated(true, Some(all_binary_labels))
.unwrap();
// Prover checks the integrity of the arithmetic labels and generates zero_sums and deltas
let prover = prover.authenticate_arithmetic_labels(seed).unwrap();
(prover, verifier)
}
/// Unzips a slice of pairs, returning items corresponding to choice

9
benches/testbench.rs Normal file
View File

@@ -0,0 +1,9 @@
use authdecode::fibonacci;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
pub fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("test2", |b| b.iter(|| fibonacci(black_box(20))));
}
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

View File

@@ -492,15 +492,15 @@ impl Circuit<F> for AuthDecodeCircuit {
&cfg,
offset,
)?;
offset += 1;
// uncomment this if in the future we may want to do more
// computations in the scratch space
// offset += 1;
// replace the last field element with the one with salt
plaintext[pt_len - 1] = last_with_salt;
println!("{:?} final `scratch_space` offset", offset);
//Ok((label_sum, plaintext))
//println!("{:?} final `scratch_space` offset", offset);
Ok((label_sum_salted, plaintext))
},
)?;

View File

@@ -47,14 +47,42 @@ mod tests {
}
#[test]
// As of Oct 2022 there appears to be a bug in halo2 which causes the prove
// times with MockProver be as long as with a real prover. Marking this test
// as expensive.
#[ignore = "expensive"]
/// Tests that the protocol runs successfully
fn halo2_e2e_test_success() {
halo2_e2e_test(false);
// This test causes the "thread ... has overflowed its stack" error
// The only way to increase the stack size is to spawn a new thread with
// the test.
// See https://github.com/rust-lang/rustfmt/issues/3473
use std::thread;
thread::Builder::new()
.stack_size(8388608)
.spawn(|| halo2_e2e_test(false))
.expect("Failed to create a test thread")
.join()
.expect("Failed to join a test thread");
}
#[test]
// As of Oct 2022 there appears to be a bug in halo2 which causes the prove
// times with MockProver be as long as with a real prover. Marking this test
// as expensive.
#[ignore = "expensive"]
/// Tests that a corrupted proof causes verification to fail
fn halo2_e2e_test_failure() {
halo2_e2e_test(true);
// This test causes the "thread ... has overflowed its stack" error
// The only way to increase the stack size is to spawn a new thread with
// the test.
// See https://github.com/rust-lang/rustfmt/issues/3473
use std::thread;
thread::Builder::new()
.stack_size(8388608)
.spawn(|| halo2_e2e_test(true))
.expect("Failed to create a test thread")
.join()
.expect("Failed to join a test thread");
}
}

View File

@@ -7,7 +7,6 @@ use halo2_proofs::plonk;
use halo2_proofs::plonk::ProvingKey;
use halo2_proofs::poly::commitment::Params;
use halo2_proofs::transcript::{Blake2bWrite, Challenge255};
use instant::Instant;
use num::BigUint;
use pasta_curves::pallas::Base as F;
use pasta_curves::EqAffine;
@@ -60,8 +59,6 @@ impl Prove for Prover {
];
all_inputs.push(tmp);
let now = Instant::now();
// prepare the proving system and generate the proof:
let circuit =
@@ -74,6 +71,8 @@ impl Prove for Prover {
let mut rng = thread_rng();
// let now = Instant::now();
let res = plonk::create_proof(
params,
pk,
@@ -86,9 +85,9 @@ impl Prove for Prover {
return Err(ProverError::ProvingBackendError);
}
println!("Proof created [{:?}]", now.elapsed());
// println!("Proof created [{:?}]", now.elapsed());
let proof = transcript.finalize();
println!("Proof size [{} kB]", proof.len() as f64 / 1024.0);
// println!("Proof size [{} kB]", proof.len() as f64 / 1024.0);
Ok(proof)
}
@@ -154,8 +153,9 @@ fn hash_internal(inputs: &Vec<BigUint>) -> Result<BigUint, ProverError> {
#[cfg(test)]
mod tests {
use super::*;
use crate::halo2_backend::circuit::{CELLS_PER_ROW, K, USEFUL_ROWS};
use crate::halo2_backend::circuit::{CELLS_PER_ROW, K};
use crate::halo2_backend::prover::hash_internal;
use crate::halo2_backend::utils::bigint_to_256bits;
use crate::halo2_backend::Curve;
use crate::prover::{ProofInput, Prove, ProverError};
use crate::tests::run_until_proofs_are_generated;
@@ -163,7 +163,6 @@ mod tests {
use crate::Proof;
use halo2_proofs::dev::MockProver;
use num::BigUint;
use std::panic::catch_unwind;
/// TestHalo2Prover is a test prover. It is the same as [Prover] except:
/// - it doesn't require a proving key
@@ -207,41 +206,52 @@ mod tests {
);
// Test with the correct inputs.
// Expect successful verification.
println!("start mockprover");
let prover = MockProver::run(K, &circuit, good_inputs.clone()).unwrap();
assert!(prover.verify().is_ok());
println!("end mockprover");
// Corrupt at least one delta which corresponds to plaintext bit 1.
// Since the plaintext was chosen randomly, we corrupt only the last
// deltas on each row - one of those deltas will correspond to a plaintext
// bit 1 with high probability.
// Expect halo2 to panic.
let mut bad_input1 = good_inputs.clone();
for i in 0..USEFUL_ROWS {
bad_input1[CELLS_PER_ROW - 1][i] = F::from(123);
// Find one delta which corresponds to plaintext bit 1 and corrupt
// the delta:
// Find the first bit 1 in plaintext
let bits = bigint_to_256bits(input.plaintext[0].clone());
let mut offset: i32 = -1;
for (i, b) in bits.iter().enumerate() {
if *b == true {
offset = i as i32;
break;
}
}
// first field element of the plaintext is not expected to have all
// bits set to zero.
assert!(offset != -1);
let offset = offset as usize;
// Find the position of the corresponding delta. The position is
// row/column in the halo2 table
let col = offset % CELLS_PER_ROW;
let row = offset / CELLS_PER_ROW;
// Corrupt the delta
let mut bad_input1 = good_inputs.clone();
bad_input1[col][row] = F::from(123);
println!("start mockprover2");
let prover = MockProver::run(K, &circuit, bad_input1.clone()).unwrap();
assert!(prover.verify().is_err());
println!("end mockprover2");
// One-by-one corrupt the plaintext hash, the label sum hash, the zero sum.
// Expect halo2 to panic.
// Expect verification error.
for i in 0..3 {
let mut bad_public_input = good_inputs.clone();
bad_public_input[CELLS_PER_ROW][i] = F::from(123);
println!("start mockprover3");
let prover = MockProver::run(K, &circuit, bad_public_input.clone()).unwrap();
assert!(prover.verify().is_err());
println!("end mockprover3");
}
// Corrupt only the plaintext.
// Expect halo2 to panic.
// Expect verification error.
let mut bad_plaintext = good_plaintext.clone();
bad_plaintext[0] = F::from(123);
@@ -250,13 +260,11 @@ mod tests {
bigint_to_f(&input.salt),
deltas_as_rows.into(),
);
println!("start mockprover4");
let prover = MockProver::run(K, &circuit, good_inputs.clone()).unwrap();
assert!(prover.verify().is_err());
println!("end mockprover4");
// Corrupt only the salt.
// Expect halo2 to panic.
// Expect verification error.
let bad_salt = BigUint::from(123u8);
let circuit = AuthDecodeCircuit::new(
@@ -264,10 +272,8 @@ mod tests {
bigint_to_f(&bad_salt),
deltas_as_rows.into(),
);
println!("start mockprover5");
let prover = MockProver::run(K, &circuit, good_inputs.clone()).unwrap();
assert!(prover.verify().is_err());
println!("end mockprover5");
Ok(Default::default())
}
@@ -325,7 +331,6 @@ mod tests {
match self.curve {
Curve::Pallas => 255,
Curve::BN254 => 254,
_ => panic!("a new curve was added. Add its field size here."),
}
}
@@ -339,11 +344,27 @@ mod tests {
}
#[test]
// As of Oct 2022 there appears to be a bug in halo2 which causes the prove
// times with MockProver be as long as with a real prover. Marking this test
// as expensive.
#[ignore = "expensive"]
/// Tests the circuit with the correct inputs as well as wrong inputs. The logic is
/// in [TestHalo2Prover]'s prove()
fn test_circuit() {
let prover = Box::new(TestHalo2Prover::new());
let verifier = Box::new(TestHalo2Verifier::new(Curve::Pallas));
let _res = run_until_proofs_are_generated(prover, verifier);
// This test causes the "thread ... has overflowed its stack" error
// The only way to increase the stack size is to spawn a new thread with
// the test.
// See https://github.com/rust-lang/rustfmt/issues/3473
use std::thread;
thread::Builder::new()
.stack_size(8388608)
.spawn(|| {
let prover = Box::new(TestHalo2Prover::new());
let verifier = Box::new(TestHalo2Verifier::new(Curve::Pallas));
let _ = run_until_proofs_are_generated(prover, verifier);
})
.expect("Failed to create a test thread")
.join()
.expect("Failed to join a test thread");
}
}

View File

@@ -2,7 +2,7 @@ use super::circuit::{CELLS_PER_ROW, USEFUL_ROWS};
use crate::utils::{boolvec_to_u8vec, u8vec_to_boolvec};
use crate::Delta;
use halo2_proofs::arithmetic::FieldExt;
use num::{BigUint, FromPrimitive};
use num::BigUint;
use pasta_curves::Fp as F;
/// Decomposes a `BigUint` into bits and returns the bits in MSB-first bit order,
@@ -86,7 +86,7 @@ fn test_f_to_bigint() {
let b = rng.gen::<u128>();
let res = f_to_bigint(&(F::from_u128(a) + F::from_u128(b)));
let expected: BigUint = BigUint::from_u128(a).unwrap() + BigUint::from_u128(b).unwrap();
let expected: BigUint = BigUint::from(a) + BigUint::from(b);
assert_eq!(res, expected);
}

View File

@@ -7,7 +7,6 @@ use halo2_proofs::plonk::VerifyingKey;
use halo2_proofs::poly::commitment::Params;
use halo2_proofs::transcript::Blake2bRead;
use halo2_proofs::transcript::Challenge255;
use instant::Instant;
use pasta_curves::pallas::Base as F;
use pasta_curves::EqAffine;
@@ -55,7 +54,7 @@ impl Verify for Verifier {
];
all_inputs.push(tmp);
let now = Instant::now();
// let now = Instant::now();
// perform the actual verification
let res = plonk::verify_proof(
params,
@@ -64,7 +63,7 @@ impl Verify for Verifier {
&[all_inputs.as_slice()],
&mut transcript,
);
println!("Proof verified [{:?}]", now.elapsed());
// println!("Proof verified [{:?}]", now.elapsed());
if res.is_err() {
return Err(VerifierError::VerificationFailed);
} else {
@@ -76,7 +75,6 @@ impl Verify for Verifier {
match self.curve {
Curve::Pallas => 255,
Curve::BN254 => 254,
_ => panic!("a new curve was added. Add its field size here."),
}
}

View File

@@ -133,7 +133,7 @@ mod tests {
// generate random plaintext of random size up to 1000 bytes
let plaintext: Vec<u8> = core::iter::repeat_with(|| rng.gen::<u8>())
.take(thread_rng().gen_range(0..1000))
.take(thread_rng().gen_range(1..1000))
.collect();
// Normally, the Prover is expected to obtain her binary labels by

View File

@@ -182,6 +182,7 @@ pub trait Prove {
fn hash(&self, inputs: &Vec<BigUint>) -> Result<BigUint, ProverError>;
}
/// Implementation of the prover in the AuthDecode protocol.
pub struct AuthDecodeProver<S = Setup>
where
S: State,
@@ -831,13 +832,13 @@ mod tests {
#[test]
/// Sets too few binary labels and triggers [ProverError::IncorrectBinaryLabelSize]
fn test_error_incorrect_binary_label_size() {
let pt_len = 1000;
let ciphertexts = vec![[[0u8; 16], [0u8; 16]]; pt_len * 8];
let plaintext_size = 1000;
let ciphertexts = vec![[[0u8; 16], [0u8; 16]]; plaintext_size];
let labels = vec![0u128];
let lsp = AuthDecodeProver {
state: LabelSumCommitment {
plaintext_size: 0,
plaintext_size,
..Default::default()
},
prover: Box::new(CorrectTestProver {}),
@@ -846,7 +847,7 @@ mod tests {
assert_eq!(
res.err().unwrap(),
ProverError::IncorrectBinaryLabelSize(pt_len, 1)
ProverError::IncorrectBinaryLabelSize(plaintext_size, 1)
);
}

View File

@@ -1,6 +1,5 @@
use crate::{Delta, ZeroSum};
use aes::{Aes128, NewBlockCipher};
use ark_ff::BigInt;
use cipher::{consts::U16, generic_array::GenericArray, BlockEncrypt};
use num::BigUint;
use sha2::{Digest, Sha256};

View File

@@ -62,11 +62,6 @@ pub struct VerifyMany {
}
impl State for VerifyMany {}
pub struct VerificationSuccessfull {
plaintext_hashes: Vec<PlaintextHash>,
}
impl State for VerificationSuccessfull {}
pub trait Verify {
/// Verifies the zk proof against public `input`s. Returns `true` on success,
/// `false` otherwise.
@@ -85,6 +80,8 @@ pub trait Verify {
/// of the last field element of each chunk.
fn chunk_size(&self) -> usize;
}
/// Implementation of the verifier in the AuthDecode protocol.
pub struct AuthDecodeVerifier<S = Setup>
where
S: State,