refactor(zk): place asserts in proof behind a condition

This commit is contained in:
Nicolas Sarlin
2024-11-14 17:05:09 +01:00
committed by Nicolas Sarlin
parent 1e19bae29a
commit 770ae22bb6
4 changed files with 94 additions and 18 deletions

View File

@@ -214,6 +214,11 @@ impl Montgomery {
pub fn four_squares(v: u128) -> [u64; 4] {
let rng = &mut StdRng::seed_from_u64(0);
// In the extreme case where the noise is exactly at the bound, v is 0
if v == 0 {
return [0; 4];
}
let f = v % 4;
if f == 2 {
let b = isqrt(v as _) as u64;

View File

@@ -132,6 +132,34 @@ impl<G: Curve> GroupElements<G> {
}
}
/// Allows to compute proof with bad inputs for tests
#[derive(PartialEq, Eq)]
enum ProofSanityCheckMode {
Panic,
}
/// Check the preconditions of the pke proof before computing it. Panic if one of the conditions
/// does not hold.
#[allow(clippy::too_many_arguments)]
fn assert_pke_proof_preconditions(
c1: &[i64],
e1: &[i64],
c2: &[i64],
e2: &[i64],
d: usize,
k_max: usize,
big_d: usize,
big_d_max: usize,
) {
assert_eq!(c1.len(), d);
assert_eq!(e1.len(), d);
assert_eq!(c2.len(), e2.len());
assert!(c2.len() <= k_max);
assert!(big_d <= big_d_max);
}
/// q (modulus) is encoded on 64b, with 0 meaning 2^64. This converts the encoded q to its effective
/// value for modular operations.
fn decode_q(q: u64) -> u128 {

View File

@@ -475,6 +475,24 @@ pub fn prove<G: Curve>(
metadata: &[u8],
load: ComputeLoad,
rng: &mut dyn RngCore,
) -> Proof<G> {
prove_impl(
public,
private_commit,
metadata,
load,
rng,
ProofSanityCheckMode::Panic,
)
}
fn prove_impl<G: Curve>(
public: (&PublicParams<G>, &PublicCommit<G>),
private_commit: &PrivateCommit<G>,
metadata: &[u8],
load: ComputeLoad,
rng: &mut dyn RngCore,
sanity_check_mode: ProofSanityCheckMode,
) -> Proof<G> {
let &PublicParams {
ref g_lists,
@@ -503,7 +521,6 @@ pub fn prove<G: Curve>(
let PrivateCommit { r, e1, m, e2, .. } = private_commit;
let k = c2.len();
assert!(k <= k_max);
let effective_t_for_decomposition = t >> msbs_zero_padding_bit_count;
@@ -512,7 +529,10 @@ pub fn prove<G: Curve>(
let big_d = d
+ k * effective_t_for_decomposition.ilog2() as usize
+ (d + k) * (2 + b_i.ilog2() as usize + b_r.ilog2() as usize);
assert!(big_d <= big_d_max);
if sanity_check_mode == ProofSanityCheckMode::Panic {
assert_pke_proof_preconditions(c1, e1, c2, e2, d, k_max, big_d, big_d_max);
}
// FIXME: div_round
let delta = {

View File

@@ -654,6 +654,24 @@ pub fn prove<G: Curve>(
metadata: &[u8],
load: ComputeLoad,
rng: &mut dyn RngCore,
) -> Proof<G> {
prove_impl(
public,
private_commit,
metadata,
load,
rng,
ProofSanityCheckMode::Panic,
)
}
fn prove_impl<G: Curve>(
public: (&PublicParams<G>, &PublicCommit<G>),
private_commit: &PrivateCommit<G>,
metadata: &[u8],
load: ComputeLoad,
rng: &mut dyn RngCore,
sanity_check_mode: ProofSanityCheckMode,
) -> Proof<G> {
_ = load;
let (
@@ -689,7 +707,6 @@ pub fn prove<G: Curve>(
let PrivateCommit { r, e1, m, e2, .. } = private_commit;
let k = c2.len();
assert!(k <= k_max);
let effective_cleartext_t = t_input >> msbs_zero_padding_bit_count;
@@ -698,7 +715,21 @@ pub fn prove<G: Curve>(
// Recompute the D for our case if k is smaller than the k max
// formula in Prove_pp: 2.
let D = d + k * effective_cleartext_t.ilog2() as usize;
assert!(D <= D_max);
let e_sqr_norm = e1
.iter()
.chain(e2)
.map(|x| sqr(x.unsigned_abs() as u128))
.sum::<u128>();
if sanity_check_mode == ProofSanityCheckMode::Panic {
assert_pke_proof_preconditions(c1, e1, c2, e2, d, k_max, D, D_max);
assert!(
sqr(B as u128) >= e_sqr_norm,
"squared norm of error ({e_sqr_norm}) exceeds threshold ({})",
sqr(B as u128)
);
}
// FIXME: div_round
let delta = {
@@ -730,18 +761,6 @@ pub fn prove<G: Curve>(
)
.collect::<Box<[_]>>();
let e_sqr_norm = e1
.iter()
.chain(e2)
.map(|x| sqr(x.unsigned_abs() as u128))
.sum::<u128>();
assert!(
sqr(B as u128) >= e_sqr_norm,
"squared norm of error ({e_sqr_norm}) exceeds threshold ({})",
sqr(B as u128)
);
let v = four_squares(sqr(B as u128) - e_sqr_norm).map(|v| v as i64);
let e1_zp = &*e1
@@ -872,7 +891,9 @@ pub fn prove<G: Curve>(
-1 => acc -= x as i128,
_ => unreachable!(),
});
assert!(acc.unsigned_abs() <= B_bound as u128);
if sanity_check_mode == ProofSanityCheckMode::Panic {
assert!(acc.unsigned_abs() <= B_bound as u128);
}
acc as i64
})
.collect::<Box<[_]>>();
@@ -970,7 +991,9 @@ pub fn prove<G: Curve>(
.flat_map(|x| x.to_le_bytes().as_ref().to_vec())
.collect::<Box<[_]>>();
assert_eq!(y.len(), w_bin.len());
if sanity_check_mode == ProofSanityCheckMode::Panic {
assert_eq!(y.len(), w_bin.len());
}
let scalars = y
.iter()
.zip(w_bin.iter())