feat(zk): check that proof and crs points are valid

This commit is contained in:
Nicolas Sarlin
2024-10-28 14:05:47 +01:00
committed by Nicolas Sarlin
parent 5c42fc950e
commit 295b6608ee
6 changed files with 532 additions and 6 deletions

View File

@@ -116,6 +116,10 @@ pub trait CurveGroupOps<Zp>:
fn to_le_bytes(self) -> impl AsRef<[u8]>;
fn double(self) -> Self;
fn normalize(self) -> Self::Affine;
fn validate_projective(&self) -> bool {
Self::validate_affine(&self.normalize())
}
fn validate_affine(affine: &Self::Affine) -> bool;
}
/// Mark that an element can be compressed, by storing only the 'x' coordinates of the affine
@@ -231,6 +235,10 @@ impl CurveGroupOps<bls12_381::Zp> for bls12_381::G1 {
inner: self.inner.into_affine(),
}
}
fn validate_affine(affine: &Self::Affine) -> bool {
affine.validate()
}
}
impl CurveGroupOps<bls12_381::Zp> for bls12_381::G2 {
@@ -271,6 +279,10 @@ impl CurveGroupOps<bls12_381::Zp> for bls12_381::G2 {
inner: self.inner.into_affine(),
}
}
fn validate_affine(affine: &Self::Affine) -> bool {
affine.validate()
}
}
impl PairingGroupOps<bls12_381::Zp, bls12_381::G1, bls12_381::G2> for bls12_381::Gt {
@@ -368,6 +380,10 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G1 {
inner: self.inner.into_affine(),
}
}
fn validate_affine(affine: &Self::Affine) -> bool {
affine.validate()
}
}
impl CurveGroupOps<bls12_446::Zp> for bls12_446::G2 {
@@ -408,6 +424,10 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G2 {
inner: self.inner.into_affine(),
}
}
fn validate_affine(affine: &Self::Affine) -> bool {
affine.validate()
}
}
impl PairingGroupOps<bls12_446::Zp, bls12_446::G1, bls12_446::G2> for bls12_446::Gt {

View File

@@ -90,6 +90,10 @@ mod g1 {
.unwrap(),
}
}
pub fn validate(&self) -> bool {
self.inner.is_on_curve() && self.inner.is_in_correct_subgroup_assuming_on_curve()
}
}
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
@@ -310,6 +314,10 @@ mod g2 {
.unwrap(),
}
}
pub fn validate(&self) -> bool {
self.inner.is_on_curve() && self.inner.is_in_correct_subgroup_assuming_on_curve()
}
}
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]

View File

@@ -92,6 +92,10 @@ mod g1 {
.unwrap(),
}
}
pub fn validate(&self) -> bool {
self.inner.is_on_curve() && self.inner.is_in_correct_subgroup_assuming_on_curve()
}
}
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
@@ -316,6 +320,10 @@ mod g2 {
}
}
pub fn validate(&self) -> bool {
self.inner.is_on_curve() && self.inner.is_in_correct_subgroup_assuming_on_curve()
}
// m is an intermediate variable that's used in both the curve point addition and pairing
// functions. we cache it since it requires a Zp division
// https://hackmd.io/@tazAymRSQCGXTUKkbh1BAg/Sk27liTW9#Math-Formula-for-Point-Addition

View File

@@ -1,4 +1,5 @@
use crate::backward_compatibility::GroupElementsVersions;
use crate::curve_api::{Compressible, Curve, CurveGroupOps, FieldOps, PairingGroupOps};
use crate::serialization::{
InvalidSerializedGroupElementsError, SerializableG1Affine, SerializableG2Affine,
@@ -6,6 +7,7 @@ use crate::serialization::{
};
use core::ops::{Index, IndexMut};
use rand::{Rng, RngCore};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use tfhe_versionable::Versionize;
#[derive(Clone, Copy, Debug, serde::Serialize, serde::Deserialize, Versionize)]
@@ -108,6 +110,16 @@ impl<G: Curve> GroupElements<G> {
message_len,
}
}
/// Check if the elements are valid for their respective groups
pub fn is_valid(&self) -> bool {
let (g_list_valid, g_hat_list_valid) = rayon::join(
|| self.g_list.0.par_iter().all(G::G1::validate_affine),
|| self.g_hat_list.0.par_iter().all(G::G2::validate_affine),
);
g_list_valid && g_hat_list_valid
}
}
impl<G: Curve> Compressible for GroupElements<G>
@@ -152,6 +164,8 @@ mod test {
#![allow(non_snake_case)]
use std::fmt::Display;
use ark_ec::{short_weierstrass, CurveConfig};
use ark_ff::UniformRand;
use bincode::ErrorKind;
use rand::rngs::StdRng;
use rand::Rng;
@@ -359,4 +373,47 @@ mod test {
PkeTestCiphertext { c1, c2 }
}
}
/// Return a point with coordinates (x, y) that is randomly chosen and not on the curve
pub(super) fn point_not_on_curve<Config: short_weierstrass::SWCurveConfig>(
rng: &mut StdRng,
) -> short_weierstrass::Affine<Config> {
loop {
let fake_x = <Config as CurveConfig>::BaseField::rand(rng);
let fake_y = <Config as CurveConfig>::BaseField::rand(rng);
let point = short_weierstrass::Affine::new_unchecked(fake_x, fake_y);
if !point.is_on_curve() {
return point;
}
}
}
/// Return a random point on the curve
pub(super) fn point_on_curve<Config: short_weierstrass::SWCurveConfig>(
rng: &mut StdRng,
) -> short_weierstrass::Affine<Config> {
loop {
let x = <Config as CurveConfig>::BaseField::rand(rng);
let is_positive = bool::rand(rng);
if let Some(point) =
short_weierstrass::Affine::get_point_from_x_unchecked(x, is_positive)
{
return point;
}
}
}
/// Return a random point that is on the curve but not in the correct subgroup
pub(super) fn point_on_curve_wrong_subgroup<Config: short_weierstrass::SWCurveConfig>(
rng: &mut StdRng,
) -> short_weierstrass::Affine<Config> {
loop {
let point = point_on_curve(rng);
if !Config::is_in_correct_subgroup_assuming_on_curve(&point) {
return point;
}
}
}
}

View File

@@ -182,6 +182,15 @@ impl<G: Curve> PublicParams<G> {
pub fn exclusive_max_noise(&self) -> u64 {
self.b
}
/// Check if the crs can be used to generate or verify a proof
///
/// This means checking that the points are:
/// - valid points of the curve
/// - in the correct subgroup
pub fn is_usable(&self) -> bool {
self.g_lists.is_valid()
}
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Versionize)]
@@ -197,6 +206,38 @@ pub struct Proof<G: Curve> {
pub(crate) compute_load_proof_fields: Option<ComputeLoadProofFields<G>>,
}
impl<G: Curve> Proof<G> {
/// Check if the proof can be used by the Verifier.
///
/// This means checking that the points in the proof are:
/// - valid points of the curve
/// - in the correct subgroup
pub fn is_usable(&self) -> bool {
let &Proof {
c_hat,
c_y,
pi,
ref compute_load_proof_fields,
} = self;
c_hat.validate_projective()
&& c_y.validate_projective()
&& pi.validate_projective()
&& compute_load_proof_fields.as_ref().map_or(
true,
|&ComputeLoadProofFields {
c_hat_t,
c_h,
pi_kzg,
}| {
c_hat_t.validate_projective()
&& c_h.validate_projective()
&& pi_kzg.validate_projective()
},
)
}
}
/// These fields can be pre-computed on the prover side in the faster Verifier scheme. If that's the
/// case, they should be included in the proof.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Versionize)]
@@ -1260,6 +1301,8 @@ pub fn verify<G: Curve>(
#[cfg(test)]
mod tests {
use crate::curve_api::{self, bls12_446};
use super::super::test::*;
use super::*;
use rand::rngs::StdRng;
@@ -1312,7 +1355,7 @@ mod tests {
let mut fake_metadata = [255u8; METADATA_LEN];
fake_metadata.fill_with(|| rng.gen::<u8>());
type Curve = crate::curve_api::Bls12_446;
type Curve = curve_api::Bls12_446;
// To check management of bigger k_max from CRS during test
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
@@ -1429,7 +1472,7 @@ mod tests {
};
let ct = testcase.encrypt(PKEV1_TEST_PARAMS);
type Curve = crate::curve_api::Bls12_446;
type Curve = curve_api::Bls12_446;
// To check management of bigger k_max from CRS during test
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
@@ -1491,7 +1534,7 @@ mod tests {
let testcase = PkeTestcase::gen(rng, PKEV1_TEST_PARAMS);
let ct = testcase.encrypt(PKEV1_TEST_PARAMS);
type Curve = crate::curve_api::Bls12_446;
type Curve = curve_api::Bls12_446;
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
@@ -1526,4 +1569,148 @@ mod tests {
verify(&proof, (&public_param, &public_commit), &testcase.metadata).unwrap()
}
}
#[test]
fn test_proof_usable() {
let PkeTestParameters {
d,
k,
B,
q,
t,
msbs_zero_padding_bit_count,
} = PKEV1_TEST_PARAMS;
let rng = &mut StdRng::seed_from_u64(0);
let testcase = PkeTestcase::gen(rng, PKEV1_TEST_PARAMS);
let ct = testcase.encrypt(PKEV1_TEST_PARAMS);
type Curve = curve_api::Bls12_446;
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
let public_param = crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng);
let (public_commit, private_commit) = commit(
testcase.a.clone(),
testcase.b.clone(),
ct.c1.clone(),
ct.c2.clone(),
testcase.r.clone(),
testcase.e1.clone(),
testcase.m.clone(),
testcase.e2.clone(),
&public_param,
rng,
);
for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
let valid_proof = prove(
(&public_param, &public_commit),
&private_commit,
&testcase.metadata,
load,
rng,
);
let compressed_proof = bincode::serialize(&valid_proof.compress()).unwrap();
let proof_that_was_compressed: Proof<Curve> =
Proof::uncompress(bincode::deserialize(&compressed_proof).unwrap()).unwrap();
assert!(valid_proof.is_usable());
assert!(proof_that_was_compressed.is_usable());
let not_on_curve_g1 = bls12_446::G1::projective(bls12_446::G1Affine {
inner: point_not_on_curve(rng),
});
let not_on_curve_g2 = bls12_446::G2::projective(bls12_446::G2Affine {
inner: point_not_on_curve(rng),
});
let not_in_group_g1 = bls12_446::G1::projective(bls12_446::G1Affine {
inner: point_on_curve_wrong_subgroup(rng),
});
let not_in_group_g2 = bls12_446::G2::projective(bls12_446::G2Affine {
inner: point_on_curve_wrong_subgroup(rng),
});
{
let mut proof = valid_proof.clone();
proof.c_hat = not_on_curve_g2;
assert!(!proof.is_usable());
proof.c_hat = not_in_group_g2;
assert!(!proof.is_usable());
}
{
let mut proof = valid_proof.clone();
proof.c_y = not_on_curve_g1;
assert!(!proof.is_usable());
proof.c_y = not_in_group_g1;
assert!(!proof.is_usable());
}
{
let mut proof = valid_proof.clone();
proof.pi = not_on_curve_g1;
assert!(!proof.is_usable());
proof.pi = not_in_group_g1;
assert!(!proof.is_usable());
}
if let Some(ref valid_compute_proof_fields) = valid_proof.compute_load_proof_fields {
{
let mut proof = valid_proof.clone();
proof.compute_load_proof_fields = Some(ComputeLoadProofFields {
c_hat_t: not_on_curve_g2,
..valid_compute_proof_fields.clone()
});
assert!(!proof.is_usable());
proof.compute_load_proof_fields = Some(ComputeLoadProofFields {
c_hat_t: not_in_group_g2,
..valid_compute_proof_fields.clone()
});
assert!(!proof.is_usable());
}
{
let mut proof = valid_proof.clone();
proof.compute_load_proof_fields = Some(ComputeLoadProofFields {
c_h: not_on_curve_g1,
..valid_compute_proof_fields.clone()
});
assert!(!proof.is_usable());
proof.compute_load_proof_fields = Some(ComputeLoadProofFields {
c_h: not_in_group_g1,
..valid_compute_proof_fields.clone()
});
assert!(!proof.is_usable());
}
{
let mut proof = valid_proof.clone();
proof.compute_load_proof_fields = Some(ComputeLoadProofFields {
pi_kzg: not_on_curve_g1,
..valid_compute_proof_fields.clone()
});
assert!(!proof.is_usable());
proof.compute_load_proof_fields = Some(ComputeLoadProofFields {
pi_kzg: not_in_group_g1,
..valid_compute_proof_fields.clone()
});
assert!(!proof.is_usable());
}
}
}
}
}

View File

@@ -9,6 +9,7 @@ use crate::serialization::{
try_vec_to_array, InvalidSerializedAffineError, InvalidSerializedPublicParamsError,
SerializableGroupElements, SerializablePKEv2PublicParams,
};
use core::marker::PhantomData;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
@@ -221,6 +222,15 @@ impl<G: Curve> PublicParams<G> {
pub fn exclusive_max_noise(&self) -> u64 {
self.B
}
/// Check if the crs can be used to generate or verify a proof
///
/// This means checking that the points are:
/// - valid points of the curve
/// - in the correct subgroup
pub fn is_usable(&self) -> bool {
self.g_lists.is_valid()
}
}
/// This represents a proof that the given ciphertext is a valid encryptions of the input messages
@@ -247,6 +257,48 @@ pub struct Proof<G: Curve> {
pub(crate) compute_load_proof_fields: Option<ComputeLoadProofFields<G>>,
}
impl<G: Curve> Proof<G> {
/// Check if the proof can be used by the Verifier.
///
/// This means checking that the points in the proof are:
/// - valid points of the curve
/// - in the correct subgroup
pub fn is_usable(&self) -> bool {
let &Proof {
C_hat_e,
C_e,
C_r_tilde,
C_R,
C_hat_bin,
C_y,
C_h1,
C_h2,
C_hat_t,
pi,
pi_kzg,
ref compute_load_proof_fields,
} = self;
C_hat_e.validate_projective()
&& C_e.validate_projective()
&& C_r_tilde.validate_projective()
&& C_R.validate_projective()
&& C_hat_bin.validate_projective()
&& C_y.validate_projective()
&& C_h1.validate_projective()
&& C_h2.validate_projective()
&& C_hat_t.validate_projective()
&& pi.validate_projective()
&& pi_kzg.validate_projective()
&& compute_load_proof_fields.as_ref().map_or(
true,
|&ComputeLoadProofFields { C_hat_h3, C_hat_w }| {
C_hat_h3.validate_projective() && C_hat_w.validate_projective()
},
)
}
}
/// These fields can be pre-computed on the prover side in the faster Verifier scheme. If that's the
/// case, they should be included in the proof.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
@@ -2368,6 +2420,8 @@ pub fn verify<G: Curve>(
#[cfg(test)]
mod tests {
use crate::curve_api::{self, bls12_446};
use super::super::test::*;
use super::*;
use rand::rngs::StdRng;
@@ -2419,7 +2473,7 @@ mod tests {
let mut fake_metadata = [255u8; METADATA_LEN];
fake_metadata.fill_with(|| rng.gen::<u8>());
type Curve = crate::curve_api::Bls12_446;
type Curve = curve_api::Bls12_446;
// To check management of bigger k_max from CRS during test
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
@@ -2536,7 +2590,7 @@ mod tests {
let ct = testcase.encrypt(PKEV2_TEST_PARAMS);
type Curve = crate::curve_api::Bls12_446;
type Curve = curve_api::Bls12_446;
// To check management of bigger k_max from CRS during test
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
@@ -2598,7 +2652,7 @@ mod tests {
let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS);
let ct = testcase.encrypt(PKEV2_TEST_PARAMS);
type Curve = crate::curve_api::Bls12_446;
type Curve = curve_api::Bls12_446;
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
@@ -2633,4 +2687,196 @@ mod tests {
verify(&proof, (&public_param, &public_commit), &testcase.metadata).unwrap()
}
}
#[test]
fn test_proof_usable() {
let PkeTestParameters {
d,
k,
B,
q,
t,
msbs_zero_padding_bit_count,
} = PKEV2_TEST_PARAMS;
let rng = &mut StdRng::seed_from_u64(0);
let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS);
let ct = testcase.encrypt(PKEV2_TEST_PARAMS);
type Curve = curve_api::Bls12_446;
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
let public_param = crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng);
let (public_commit, private_commit) = commit(
testcase.a.clone(),
testcase.b.clone(),
ct.c1.clone(),
ct.c2.clone(),
testcase.r.clone(),
testcase.e1.clone(),
testcase.m.clone(),
testcase.e2.clone(),
&public_param,
rng,
);
for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
let valid_proof = prove(
(&public_param, &public_commit),
&private_commit,
&testcase.metadata,
load,
rng,
);
let compressed_proof = bincode::serialize(&valid_proof.compress()).unwrap();
let proof_that_was_compressed: Proof<Curve> =
Proof::uncompress(bincode::deserialize(&compressed_proof).unwrap()).unwrap();
assert!(valid_proof.is_usable());
assert!(proof_that_was_compressed.is_usable());
let not_on_curve_g1 = bls12_446::G1::projective(bls12_446::G1Affine {
inner: point_not_on_curve(rng),
});
let not_on_curve_g2 = bls12_446::G2::projective(bls12_446::G2Affine {
inner: point_not_on_curve(rng),
});
let not_in_group_g1 = bls12_446::G1::projective(bls12_446::G1Affine {
inner: point_on_curve_wrong_subgroup(rng),
});
let not_in_group_g2 = bls12_446::G2::projective(bls12_446::G2Affine {
inner: point_on_curve_wrong_subgroup(rng),
});
{
let mut proof = valid_proof.clone();
proof.C_hat_e = not_on_curve_g2;
assert!(!proof.is_usable());
proof.C_hat_e = not_in_group_g2;
assert!(!proof.is_usable());
}
{
let mut proof = valid_proof.clone();
proof.C_e = not_on_curve_g1;
assert!(!proof.is_usable());
proof.C_e = not_in_group_g1;
assert!(!proof.is_usable());
}
{
let mut proof = valid_proof.clone();
proof.C_r_tilde = not_on_curve_g1;
assert!(!proof.is_usable());
proof.C_r_tilde = not_in_group_g1;
assert!(!proof.is_usable());
}
{
let mut proof = valid_proof.clone();
proof.C_R = not_on_curve_g1;
assert!(!proof.is_usable());
proof.C_R = not_in_group_g1;
assert!(!proof.is_usable());
}
{
let mut proof = valid_proof.clone();
proof.C_hat_bin = not_on_curve_g2;
assert!(!proof.is_usable());
proof.C_hat_bin = not_in_group_g2;
assert!(!proof.is_usable());
}
{
let mut proof = valid_proof.clone();
proof.C_y = not_on_curve_g1;
assert!(!proof.is_usable());
proof.C_y = not_in_group_g1;
assert!(!proof.is_usable());
}
{
let mut proof = valid_proof.clone();
proof.C_h1 = not_on_curve_g1;
assert!(!proof.is_usable());
proof.C_h1 = not_in_group_g1;
assert!(!proof.is_usable());
}
{
let mut proof = valid_proof.clone();
proof.C_h2 = not_on_curve_g1;
assert!(!proof.is_usable());
proof.C_h2 = not_in_group_g1;
assert!(!proof.is_usable());
}
{
let mut proof = valid_proof.clone();
proof.C_hat_t = not_on_curve_g2;
assert!(!proof.is_usable());
proof.C_hat_t = not_in_group_g2;
assert!(!proof.is_usable());
}
{
let mut proof = valid_proof.clone();
proof.pi = not_on_curve_g1;
assert!(!proof.is_usable());
proof.pi = not_in_group_g1;
assert!(!proof.is_usable());
}
{
let mut proof = valid_proof.clone();
proof.pi_kzg = not_on_curve_g1;
assert!(!proof.is_usable());
proof.pi_kzg = not_in_group_g1;
assert!(!proof.is_usable());
}
if let Some(ref valid_compute_proof_fields) = valid_proof.compute_load_proof_fields {
{
let mut proof = valid_proof.clone();
proof.compute_load_proof_fields = Some(ComputeLoadProofFields {
C_hat_h3: not_on_curve_g2,
C_hat_w: valid_compute_proof_fields.C_hat_w,
});
assert!(!proof.is_usable());
proof.compute_load_proof_fields = Some(ComputeLoadProofFields {
C_hat_h3: not_in_group_g2,
C_hat_w: valid_compute_proof_fields.C_hat_w,
});
assert!(!proof.is_usable());
}
{
let mut proof = valid_proof.clone();
proof.compute_load_proof_fields = Some(ComputeLoadProofFields {
C_hat_h3: valid_compute_proof_fields.C_hat_h3,
C_hat_w: not_on_curve_g2,
});
assert!(!proof.is_usable());
proof.compute_load_proof_fields = Some(ComputeLoadProofFields {
C_hat_h3: valid_compute_proof_fields.C_hat_h3,
C_hat_w: not_in_group_g2,
});
assert!(!proof.is_usable());
}
}
}
}
}