fix(zk): add a size check for the public key

This commit is contained in:
Nicolas Sarlin
2025-09-09 17:30:04 +02:00
committed by Nicolas Sarlin
parent 01651d6fb2
commit bfbf638fed
3 changed files with 220 additions and 6 deletions

View File

@@ -154,6 +154,8 @@ pub(crate) enum ProofSanityCheckMode {
/// does not hold.
#[allow(clippy::too_many_arguments)]
fn assert_pke_proof_preconditions(
a: &[i64],
b: &[i64],
c1: &[i64],
e1: &[i64],
c2: &[i64],
@@ -164,6 +166,8 @@ fn assert_pke_proof_preconditions(
big_d_max: usize,
) {
assert!(k_max <= d);
assert_eq!(a.len(), d);
assert_eq!(b.len(), d);
assert_eq!(c1.len(), d);
assert_eq!(e1.len(), d);
@@ -442,6 +446,13 @@ mod test {
No,
}
#[derive(Clone, Copy, Eq, PartialEq)]
pub(super) enum InputSizeVariation {
Oversized,
Undersized,
Nominal,
}
pub(super) fn serialize_then_deserialize<
Params: Compressible + Serialize + for<'de> Deserialize<'de>,
>(

View File

@@ -607,7 +607,7 @@ fn prove_impl<G: Curve>(
+ (d + k) * (2 + b_i.ilog2() as usize + b_r.ilog2() as usize);
if sanity_check_mode == ProofSanityCheckMode::Panic {
assert_pke_proof_preconditions(c1, e1, c2, e2, d, k_max, big_d, big_d_max);
assert_pke_proof_preconditions(a, b, c1, e1, c2, e2, d, k_max, big_d, big_d_max);
}
// FIXME: div_round
@@ -1095,6 +1095,10 @@ pub fn verify<G: Curve>(
return Err(());
}
if a.len() != d || b.len() != d {
return Err(());
}
let effective_t_for_decomposition = t >> msbs_zero_padding_bit_count;
let big_d = d
@@ -1787,6 +1791,94 @@ mod tests {
}
}
/// Test that the proof is rejected without panic if the public key elements are not of the
/// correct size
#[test]
fn test_pke_wrong_pk_size() {
let PkeTestParameters {
d,
k,
B,
q,
t,
msbs_zero_padding_bit_count,
} = PKEV1_TEST_PARAMS;
let seed = thread_rng().gen();
println!("pke_wrong_pk_size seed: {seed:x}");
let rng = &mut StdRng::seed_from_u64(seed);
let testcase = PkeTestcase::gen(rng, PKEV1_TEST_PARAMS);
let ct = testcase.encrypt(PKEV1_TEST_PARAMS);
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
let crs = crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng);
let (public_commit, private_commit) = commit(
testcase.a.clone(),
testcase.b.clone(),
ct.c1.clone(),
ct.c2.clone(),
testcase.r.clone(),
testcase.e1.clone(),
testcase.m.clone(),
testcase.e2.clone(),
&crs,
);
for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
let proof = prove(
(&crs, &public_commit),
&private_commit,
&testcase.metadata,
load,
&seed.to_le_bytes(),
);
for (a_size_kind, b_size_kind) in itertools::iproduct!(
[
InputSizeVariation::Oversized,
InputSizeVariation::Undersized,
InputSizeVariation::Nominal,
],
[
InputSizeVariation::Oversized,
InputSizeVariation::Undersized,
InputSizeVariation::Nominal,
]
) {
if a_size_kind == InputSizeVariation::Nominal
&& b_size_kind == InputSizeVariation::Nominal
{
// This is the nominal case that is already tested
continue;
}
let mut public_commit = public_commit.clone();
match a_size_kind {
InputSizeVariation::Oversized => public_commit.a.push(rng.gen()),
InputSizeVariation::Undersized => {
public_commit.a.pop();
}
InputSizeVariation::Nominal => {}
};
match b_size_kind {
InputSizeVariation::Oversized => public_commit.b.push(rng.gen()),
InputSizeVariation::Undersized => {
public_commit.b.pop();
}
InputSizeVariation::Nominal => {}
};
// Should not panic but return an error
assert!(verify(&proof, (&crs, &public_commit), &testcase.metadata).is_err())
}
}
}
/// Test verification with modified ciphertexts
#[test]
fn test_bad_ct() {

View File

@@ -893,7 +893,7 @@ fn prove_impl<G: Curve>(
.sum::<u128>();
if sanity_check_mode == ProofSanityCheckMode::Panic {
assert_pke_proof_preconditions(c1, e1, c2, e2, d, k_max, D, D_max);
assert_pke_proof_preconditions(a, b, c1, e1, c2, e2, d, k_max, D, D_max);
assert!(
B_squared >= e_sqr_norm,
"squared norm of error ({e_sqr_norm}) exceeds threshold ({B_squared})",
@@ -1076,7 +1076,16 @@ fn prove_impl<G: Curve>(
let (theta, theta_hash) = t_hash.gen_theta();
let mut a_theta = vec![G::Zp::ZERO; D];
compute_a_theta::<G>(&mut a_theta, &theta, a, k, b, effective_cleartext_t, delta);
compute_a_theta::<G>(
&mut a_theta,
&theta,
a,
d,
k,
b,
effective_cleartext_t,
delta,
);
let t_theta = theta
.iter()
@@ -1606,6 +1615,7 @@ fn compute_a_theta<G: Curve>(
a_theta: &mut [G::Zp],
theta: &[G::Zp],
a: &[i64],
d: usize,
k: usize,
b: &[i64],
t: u64,
@@ -1620,8 +1630,8 @@ fn compute_a_theta<G: Curve>(
// ...
// delta g[log t].T theta2_k
// ]
let d = a.len();
assert_eq!(a.len(), d);
assert!(theta.len() >= d);
let theta1 = &theta[..d];
let theta2 = &theta[d..];
@@ -1770,6 +1780,10 @@ pub fn verify_impl<G: Curve>(
return Err(());
}
if a.len() != d || b.len() != d {
return Err(());
}
let effective_cleartext_t = t_input >> msbs_zero_padding_bit_count;
let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k);
let (_, D, _, m_bound) = compute_crs_params(
@@ -1815,7 +1829,16 @@ pub fn verify_impl<G: Curve>(
let (theta, theta_hash) = t_hash.gen_theta();
let mut a_theta = vec![G::Zp::ZERO; D];
compute_a_theta::<G>(&mut a_theta, &theta, a, k, b, effective_cleartext_t, delta);
compute_a_theta::<G>(
&mut a_theta,
&theta,
a,
d,
k,
b,
effective_cleartext_t,
delta,
);
let t_theta = theta
.iter()
@@ -2778,6 +2801,94 @@ mod tests {
}
}
/// Test that the proof is rejected without panic if the public key elements are not of the
/// correct size
#[test]
fn test_pke_wrong_pk_size() {
let PkeTestParameters {
d,
k,
B,
q,
t,
msbs_zero_padding_bit_count,
} = PKEV2_TEST_PARAMS;
let seed = thread_rng().gen();
println!("pke_wrong_pk_size seed: {seed:x}");
let rng = &mut StdRng::seed_from_u64(seed);
let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS);
let ct = testcase.encrypt(PKEV2_TEST_PARAMS);
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
let crs = crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng);
let (public_commit, private_commit) = commit(
testcase.a.clone(),
testcase.b.clone(),
ct.c1.clone(),
ct.c2.clone(),
testcase.r.clone(),
testcase.e1.clone(),
testcase.m.clone(),
testcase.e2.clone(),
&crs,
);
for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
let proof = prove(
(&crs, &public_commit),
&private_commit,
&testcase.metadata,
load,
&seed.to_le_bytes(),
);
for (a_size_kind, b_size_kind) in itertools::iproduct!(
[
InputSizeVariation::Oversized,
InputSizeVariation::Undersized,
InputSizeVariation::Nominal,
],
[
InputSizeVariation::Oversized,
InputSizeVariation::Undersized,
InputSizeVariation::Nominal,
]
) {
if a_size_kind == InputSizeVariation::Nominal
&& b_size_kind == InputSizeVariation::Nominal
{
// This is the nominal case that is already tested
continue;
}
let mut public_commit = public_commit.clone();
match a_size_kind {
InputSizeVariation::Oversized => public_commit.a.push(rng.gen()),
InputSizeVariation::Undersized => {
public_commit.a.pop();
}
InputSizeVariation::Nominal => {}
};
match b_size_kind {
InputSizeVariation::Oversized => public_commit.b.push(rng.gen()),
InputSizeVariation::Undersized => {
public_commit.b.pop();
}
InputSizeVariation::Nominal => {}
};
// Should not panic but return an error
assert!(verify(&proof, (&crs, &public_commit), &testcase.metadata).is_err())
}
}
}
/// Test verification with modified ciphertexts
#[test]
fn test_bad_ct() {