chore(zk): small refactor of tests to use assert_prove_and_verify

This commit is contained in:
Nicolas Sarlin
2024-11-22 17:36:42 +01:00
committed by Nicolas Sarlin
parent 530b18063a
commit 81f071c30e
3 changed files with 40 additions and 58 deletions

View File

@@ -133,7 +133,7 @@ impl<G: Curve> GroupElements<G> {
} }
/// Allows to compute proof with bad inputs for tests /// Allows to compute proof with bad inputs for tests
#[derive(PartialEq, Eq)] #[derive(Copy, Clone, PartialEq, Eq)]
enum ProofSanityCheckMode { enum ProofSanityCheckMode {
Panic, Panic,
#[cfg(test)] #[cfg(test)]

View File

@@ -1262,6 +1262,8 @@ mod tests {
use rand::rngs::StdRng; use rand::rngs::StdRng;
use rand::{Rng, SeedableRng}; use rand::{Rng, SeedableRng};
type Curve = curve_api::Bls12_446;
/// Compact key params used with pkev1 /// Compact key params used with pkev1
pub(super) const PKEV1_TEST_PARAMS: PkeTestParameters = PkeTestParameters { pub(super) const PKEV1_TEST_PARAMS: PkeTestParameters = PkeTestParameters {
d: 1024, d: 1024,
@@ -1310,8 +1312,6 @@ mod tests {
let mut fake_metadata = [255u8; METADATA_LEN]; let mut fake_metadata = [255u8; METADATA_LEN];
fake_metadata.fill_with(|| rng.gen::<u8>()); fake_metadata.fill_with(|| rng.gen::<u8>());
type Curve = curve_api::Bls12_446;
// To check management of bigger k_max from CRS during test // To check management of bigger k_max from CRS during test
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k)); let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
@@ -1397,9 +1397,9 @@ mod tests {
} }
} }
fn prove_and_verify<G: Curve>( fn prove_and_verify(
testcase: &PkeTestcase, testcase: &PkeTestcase,
crs: &PublicParams<G>, crs: &PublicParams<Curve>,
load: ComputeLoad, load: ComputeLoad,
rng: &mut StdRng, rng: &mut StdRng,
) -> VerificationResult { ) -> VerificationResult {
@@ -1434,10 +1434,10 @@ mod tests {
} }
} }
fn assert_prove_and_verify<G: Curve>( fn assert_prove_and_verify(
testcase: &PkeTestcase, testcase: &PkeTestcase,
testcase_name: &str, testcase_name: &str,
crs: &PublicParams<G>, crs: &PublicParams<Curve>,
rng: &mut StdRng, rng: &mut StdRng,
expected_result: VerificationResult, expected_result: VerificationResult,
) { ) {
@@ -1466,8 +1466,6 @@ mod tests {
let testcase = PkeTestcase::gen(rng, PKEV1_TEST_PARAMS); let testcase = PkeTestcase::gen(rng, PKEV1_TEST_PARAMS);
type Curve = curve_api::Bls12_446;
// A CRS where the number of slots = the number of messages to encrypt // A CRS where the number of slots = the number of messages to encrypt
let crs = crs_gen::<Curve>(d, k, B, q, t, msbs_zero_padding_bit_count, rng); let crs = crs_gen::<Curve>(d, k, B, q, t, msbs_zero_padding_bit_count, rng);
@@ -1630,7 +1628,6 @@ mod tests {
}; };
let ct = testcase.encrypt(PKEV1_TEST_PARAMS); let ct = testcase.encrypt(PKEV1_TEST_PARAMS);
type Curve = curve_api::Bls12_446;
// To check management of bigger k_max from CRS during test // To check management of bigger k_max from CRS during test
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k)); let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));

View File

@@ -2412,6 +2412,8 @@ mod tests {
use rand::rngs::StdRng; use rand::rngs::StdRng;
use rand::{Rng, SeedableRng}; use rand::{Rng, SeedableRng};
type Curve = curve_api::Bls12_446;
/// Compact key params used with pkev2 /// Compact key params used with pkev2
pub(super) const PKEV2_TEST_PARAMS: PkeTestParameters = PkeTestParameters { pub(super) const PKEV2_TEST_PARAMS: PkeTestParameters = PkeTestParameters {
d: 2048, d: 2048,
@@ -2459,8 +2461,6 @@ mod tests {
let mut fake_metadata = [255u8; METADATA_LEN]; let mut fake_metadata = [255u8; METADATA_LEN];
fake_metadata.fill_with(|| rng.gen::<u8>()); fake_metadata.fill_with(|| rng.gen::<u8>());
type Curve = curve_api::Bls12_446;
// To check management of bigger k_max from CRS during test // To check management of bigger k_max from CRS during test
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k)); let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
@@ -2546,14 +2546,14 @@ mod tests {
} }
} }
fn prove_and_verify<G: Curve>( fn prove_and_verify(
testcase: &PkeTestcase, testcase: &PkeTestcase,
crs: &PublicParams<G>, ct: &PkeTestCiphertext,
crs: &PublicParams<Curve>,
load: ComputeLoad, load: ComputeLoad,
sanity_check_mode: ProofSanityCheckMode,
rng: &mut StdRng, rng: &mut StdRng,
) -> VerificationResult { ) -> VerificationResult {
let ct = testcase.encrypt_unchecked(PKEV2_TEST_PARAMS);
let (public_commit, private_commit) = commit( let (public_commit, private_commit) = commit(
testcase.a.clone(), testcase.a.clone(),
testcase.b.clone(), testcase.b.clone(),
@@ -2573,7 +2573,7 @@ mod tests {
&testcase.metadata, &testcase.metadata,
load, load,
rng, rng,
ProofSanityCheckMode::Ignore, sanity_check_mode,
); );
if verify(&proof, (crs, &public_commit), &testcase.metadata).is_ok() { if verify(&proof, (crs, &public_commit), &testcase.metadata).is_ok() {
@@ -2583,16 +2583,18 @@ mod tests {
} }
} }
fn assert_prove_and_verify<G: Curve>( fn assert_prove_and_verify(
testcase: &PkeTestcase, testcase: &PkeTestcase,
ct: &PkeTestCiphertext,
testcase_name: &str, testcase_name: &str,
crs: &PublicParams<G>, crs: &PublicParams<Curve>,
rng: &mut StdRng, sanity_check_mode: ProofSanityCheckMode,
expected_result: VerificationResult, expected_result: VerificationResult,
rng: &mut StdRng,
) { ) {
for load in [ComputeLoad::Proof, ComputeLoad::Verify] { for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
assert_eq!( assert_eq!(
prove_and_verify(testcase, crs, load, rng), prove_and_verify(testcase, ct, crs, load, sanity_check_mode, rng),
expected_result, expected_result,
"Testcase {testcase_name} failed" "Testcase {testcase_name} failed"
) )
@@ -2785,8 +2787,6 @@ mod tests {
let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS); let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS);
type Curve = curve_api::Bls12_446;
let crs = crs_gen::<Curve>(d, k, B, q, t, msbs_zero_padding_bit_count, rng); let crs = crs_gen::<Curve>(d, k, B, q, t, msbs_zero_padding_bit_count, rng);
let crs_max_k = crs_gen::<Curve>(d, d, B, q, t, msbs_zero_padding_bit_count, rng); let crs_max_k = crs_gen::<Curve>(d, d, B, q, t, msbs_zero_padding_bit_count, rng);
@@ -2848,19 +2848,24 @@ mod tests {
expected_result, expected_result,
} in testcases } in testcases
{ {
let ct = testcase.encrypt_unchecked(PKEV2_TEST_PARAMS);
assert_prove_and_verify( assert_prove_and_verify(
&testcase, &testcase,
&ct,
&format!("{name}_crs"), &format!("{name}_crs"),
&crs, &crs,
rng, ProofSanityCheckMode::Ignore,
expected_result, expected_result,
rng,
); );
assert_prove_and_verify( assert_prove_and_verify(
&testcase, &testcase,
&ct,
&format!("{name}_crs_max_k"), &format!("{name}_crs_max_k"),
&crs_max_k, &crs_max_k,
rng, ProofSanityCheckMode::Ignore,
expected_result, expected_result,
rng,
); );
} }
} }
@@ -2926,8 +2931,6 @@ mod tests {
let ct = testcase.encrypt(PKEV2_TEST_PARAMS); let ct = testcase.encrypt(PKEV2_TEST_PARAMS);
type Curve = curve_api::Bls12_446;
// To check management of bigger k_max from CRS during test // To check management of bigger k_max from CRS during test
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k)); let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
@@ -2938,37 +2941,23 @@ mod tests {
let public_param_that_was_not_compressed = let public_param_that_was_not_compressed =
serialize_then_deserialize(&original_public_param, Compress::No).unwrap(); serialize_then_deserialize(&original_public_param, Compress::No).unwrap();
for public_param in [ for (public_param, test_name) in [
original_public_param, (original_public_param, "original_params"),
public_param_that_was_compressed, (
public_param_that_was_not_compressed, public_param_that_was_compressed,
"serialized_compressed_params",
),
(public_param_that_was_not_compressed, "serialize_params"),
] { ] {
let (public_commit, private_commit) = commit( assert_prove_and_verify(
testcase.a.clone(), &testcase,
testcase.b.clone(), &ct,
ct.c1.clone(), test_name,
ct.c2.clone(),
testcase.r.clone(),
testcase.e1.clone(),
testcase.m.clone(),
testcase.e2.clone(),
&public_param, &public_param,
ProofSanityCheckMode::Panic,
VerificationResult::Reject,
rng, rng,
); );
for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
let proof = prove(
(&public_param, &public_commit),
&private_commit,
&testcase.metadata,
load,
rng,
);
assert!(
verify(&proof, (&public_param, &public_commit), &testcase.metadata).is_err()
);
}
} }
} }
@@ -2989,8 +2978,6 @@ mod tests {
let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS); let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS);
let ct = testcase.encrypt(PKEV2_TEST_PARAMS); let ct = testcase.encrypt(PKEV2_TEST_PARAMS);
type Curve = curve_api::Bls12_446;
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k)); let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
let public_param = crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng); let public_param = crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng);
@@ -3042,8 +3029,6 @@ mod tests {
let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS); let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS);
let ct = testcase.encrypt(PKEV2_TEST_PARAMS); let ct = testcase.encrypt(PKEV2_TEST_PARAMS);
type Curve = curve_api::Bls12_446;
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k)); let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
let public_param = crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng); let public_param = crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng);