feat(zk): add compact hash mode for zkv2

Author:    Nicolas Sarlin
Date:      2025-06-18 18:29:30 +02:00
Committer: Nicolas Sarlin
Parent:    215ded90c0
Commit:    c475dc058e
3 changed files with 978 additions and 270 deletions
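
Each proof now records which transcript layout it was generated with, via a `PkeV2HashMode` value exported from the new `hashes` module (whose diff is suppressed below). Only three variants appear in this diff: `BackwardCompat` (assigned to proofs upgraded from older versions), `Classical` (exercised by the legacy-hash test), and `Compact` (the default for freshly generated proofs). The following is a minimal sketch of what the enum could look like; the derive list and doc comments are assumptions, not part of the commit:

    // Sketch only: the variant names come from this diff, everything else is assumed.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Versionize)]
    #[versionize(PkeV2HashModeVersions)]
    pub(crate) enum PkeV2HashMode {
        /// Layout used by proofs generated before this commit.
        BackwardCompat,
        /// Explicit non-compact layout.
        Classical,
        /// New compact layout, the default for freshly generated proofs.
        Compact,
    }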

File 1 of 3

@@ -1,11 +1,13 @@
// to follow the notation of the paper
#![allow(non_snake_case)]
use std::convert::Infallible;
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
use crate::curve_api::{CompressedG1, CompressedG2, Compressible, Curve};
use crate::proofs::pke_v2::{
CompressedComputeLoadProofFields, CompressedProof, ComputeLoadProofFields, Proof,
CompressedComputeLoadProofFields, CompressedProof, ComputeLoadProofFields, PkeV2HashMode, Proof,
};
use super::IncompleteProof;
@@ -28,37 +30,107 @@ pub struct ProofV0<G: Curve> {
C_hat_w: Option<G::G2>,
}
impl<G: Curve> Upgrade<Proof<G>> for ProofV0<G> {
impl<G: Curve> Upgrade<ProofV1<G>> for ProofV0<G> {
type Error = IncompleteProof;
fn upgrade(self) -> Result<Proof<G>, Self::Error> {
let compute_load_proof_fields = match (self.C_hat_h3, self.C_hat_w) {
fn upgrade(self) -> Result<ProofV1<G>, Self::Error> {
let ProofV0 {
C_hat_e,
C_e,
C_r_tilde,
C_R,
C_hat_bin,
C_y,
C_h1,
C_h2,
C_hat_t,
pi,
pi_kzg,
C_hat_h3,
C_hat_w,
} = self;
let compute_load_proof_fields = match (C_hat_h3, C_hat_w) {
(None, None) => None,
(Some(C_hat_h3), Some(C_hat_w)) => Some(ComputeLoadProofFields { C_hat_h3, C_hat_w }),
_ => return Err(IncompleteProof),
};
Ok(Proof {
C_hat_e: self.C_hat_e,
C_e: self.C_e,
C_r_tilde: self.C_r_tilde,
C_R: self.C_R,
C_hat_bin: self.C_hat_bin,
C_y: self.C_y,
C_h1: self.C_h1,
C_h2: self.C_h2,
C_hat_t: self.C_hat_t,
pi: self.pi,
pi_kzg: self.pi_kzg,
Ok(ProofV1 {
C_hat_e,
C_e,
C_r_tilde,
C_R,
C_hat_bin,
C_y,
C_h1,
C_h2,
C_hat_t,
pi,
pi_kzg,
compute_load_proof_fields,
})
}
}
#[derive(Version)]
pub struct ProofV1<G: Curve> {
C_hat_e: G::G2,
C_e: G::G1,
C_r_tilde: G::G1,
C_R: G::G1,
C_hat_bin: G::G2,
C_y: G::G1,
C_h1: G::G1,
C_h2: G::G1,
C_hat_t: G::G2,
pi: G::G1,
pi_kzg: G::G1,
compute_load_proof_fields: Option<ComputeLoadProofFields<G>>,
}
impl<G: Curve> Upgrade<Proof<G>> for ProofV1<G> {
type Error = Infallible;
fn upgrade(self) -> Result<Proof<G>, Self::Error> {
let ProofV1 {
C_hat_e,
C_e,
C_r_tilde,
C_R,
C_hat_bin,
C_y,
C_h1,
C_h2,
C_hat_t,
pi,
pi_kzg,
compute_load_proof_fields,
} = self;
Ok(Proof {
C_hat_e,
C_e,
C_r_tilde,
C_R,
C_hat_bin,
C_y,
C_h1,
C_h2,
C_hat_t,
pi,
pi_kzg,
compute_load_proof_fields,
hash_mode: PkeV2HashMode::BackwardCompat,
})
}
}
#[derive(VersionsDispatch)]
pub enum ProofVersions<G: Curve> {
V0(ProofV0<G>),
V1(Proof<G>),
V1(ProofV1<G>),
V2(Proof<G>),
}
#[derive(VersionsDispatch)]
@@ -67,6 +139,7 @@ pub(crate) enum ComputeLoadProofFieldsVersions<G: Curve> {
V0(ComputeLoadProofFields<G>),
}
#[derive(Version)]
pub struct CompressedProofV0<G: Curve>
where
G::G1: Compressible,
@@ -88,15 +161,31 @@ where
C_hat_w: Option<CompressedG2<G>>,
}
impl<G: Curve> Upgrade<CompressedProof<G>> for CompressedProofV0<G>
impl<G: Curve> Upgrade<CompressedProofV1<G>> for CompressedProofV0<G>
where
G::G1: Compressible,
G::G2: Compressible,
{
type Error = IncompleteProof;
fn upgrade(self) -> Result<CompressedProof<G>, Self::Error> {
let compute_load_proof_fields = match (self.C_hat_h3, self.C_hat_w) {
fn upgrade(self) -> Result<CompressedProofV1<G>, Self::Error> {
let CompressedProofV0 {
C_hat_e,
C_e,
C_r_tilde,
C_R,
C_hat_bin,
C_y,
C_h1,
C_h2,
C_hat_t,
pi,
pi_kzg,
C_hat_h3,
C_hat_w,
} = self;
let compute_load_proof_fields = match (C_hat_h3, C_hat_w) {
(None, None) => None,
(Some(C_hat_h3), Some(C_hat_w)) => {
Some(CompressedComputeLoadProofFields { C_hat_h3, C_hat_w })
@@ -104,23 +193,84 @@ where
_ => return Err(IncompleteProof),
};
Ok(CompressedProof {
C_hat_e: self.C_hat_e,
C_e: self.C_e,
C_r_tilde: self.C_r_tilde,
C_R: self.C_R,
C_hat_bin: self.C_hat_bin,
C_y: self.C_y,
C_h1: self.C_h1,
C_h2: self.C_h2,
C_hat_t: self.C_hat_t,
pi: self.pi,
pi_kzg: self.pi_kzg,
Ok(CompressedProofV1 {
C_hat_e,
C_e,
C_r_tilde,
C_R,
C_hat_bin,
C_y,
C_h1,
C_h2,
C_hat_t,
pi,
pi_kzg,
compute_load_proof_fields,
})
}
}
#[derive(Version)]
pub struct CompressedProofV1<G: Curve>
where
G::G1: Compressible,
G::G2: Compressible,
{
C_hat_e: CompressedG2<G>,
C_e: CompressedG1<G>,
C_r_tilde: CompressedG1<G>,
C_R: CompressedG1<G>,
C_hat_bin: CompressedG2<G>,
C_y: CompressedG1<G>,
C_h1: CompressedG1<G>,
C_h2: CompressedG1<G>,
C_hat_t: CompressedG2<G>,
pi: CompressedG1<G>,
pi_kzg: CompressedG1<G>,
compute_load_proof_fields: Option<CompressedComputeLoadProofFields<G>>,
}
impl<G: Curve> Upgrade<CompressedProof<G>> for CompressedProofV1<G>
where
G::G1: Compressible,
G::G2: Compressible,
{
type Error = Infallible;
fn upgrade(self) -> Result<CompressedProof<G>, Self::Error> {
let CompressedProofV1 {
C_hat_e,
C_e,
C_r_tilde,
C_R,
C_hat_bin,
C_y,
C_h1,
C_h2,
C_hat_t,
pi,
pi_kzg,
compute_load_proof_fields,
} = self;
Ok(CompressedProof {
C_hat_e,
C_e,
C_r_tilde,
C_R,
C_hat_bin,
C_y,
C_h1,
C_h2,
C_hat_t,
pi,
pi_kzg,
compute_load_proof_fields,
hash_mode: PkeV2HashMode::BackwardCompat,
})
}
}
#[derive(VersionsDispatch)]
pub enum CompressedProofVersions<G: Curve>
where
@@ -140,3 +290,9 @@ where
#[allow(dead_code)]
V0(CompressedComputeLoadProofFields<G>),
}
#[derive(VersionsDispatch)]
pub(crate) enum PkeV2HashModeVersions {
#[allow(dead_code)]
V0(PkeV2HashMode),
}
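
The backward-compatibility changes above splice `ProofV1`/`CompressedProofV1` between the existing V0 types and the current `Proof`/`CompressedProof`: V0 upgrades to V1 exactly as it previously upgraded to `Proof` (and can still fail with `IncompleteProof` when only one of the two compute-load fields is present), while V1 upgrades infallibly to the new layout, filling the added field with `PkeV2HashMode::BackwardCompat`. A minimal sketch of how the two steps compose when chained by hand (the versioning dispatch normally does this; the helper name is made up):

    // Hypothetical helper, not part of the commit: chains the two Upgrade
    // impls defined above for an uncompressed proof.
    fn upgrade_proof_from_v0<G: Curve>(old: ProofV0<G>) -> Result<Proof<G>, IncompleteProof> {
        // Step 1 can fail if exactly one of C_hat_h3 / C_hat_w is missing.
        let v1: ProofV1<G> = old.upgrade()?;
        // Step 2 cannot fail (Error = Infallible) and tags the proof with
        // PkeV2HashMode::BackwardCompat so it keeps its original hash layout.
        let latest: Proof<G> = v1.upgrade().expect("infallible");
        Ok(latest)
    }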

File 2 of 3: diff suppressed because it is too large.

File 3 of 3

@@ -12,6 +12,7 @@ use crate::serialization::{
};
use core::marker::PhantomData;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
@@ -19,6 +20,8 @@ mod hashes;
use hashes::RHash;
pub(crate) use hashes::PkeV2HashMode;
fn bit_iter(x: u64, nbits: u32) -> impl Iterator<Item = bool> {
(0..nbits).map(move |idx| ((x >> idx) & 1) != 0)
}
@@ -366,8 +369,8 @@ pub struct Proof<G: Curve> {
pub(crate) C_hat_t: G::G2,
pub(crate) pi: G::G1,
pub(crate) pi_kzg: G::G1,
pub(crate) compute_load_proof_fields: Option<ComputeLoadProofFields<G>>,
pub(crate) hash_mode: PkeV2HashMode,
}
impl<G: Curve> Proof<G> {
@@ -390,6 +393,7 @@ impl<G: Curve> Proof<G> {
pi,
pi_kzg,
ref compute_load_proof_fields,
hash_mode: _,
} = self;
C_hat_e.validate_projective()
@@ -455,8 +459,8 @@ where
pub(crate) C_hat_t: CompressedG2<G>,
pub(crate) pi: CompressedG1<G>,
pub(crate) pi_kzg: CompressedG1<G>,
pub(crate) compute_load_proof_fields: Option<CompressedComputeLoadProofFields<G>>,
pub(crate) hash_mode: PkeV2HashMode,
}
#[derive(Serialize, Deserialize, Versionize)]
@@ -497,6 +501,7 @@ where
pi,
pi_kzg,
compute_load_proof_fields,
hash_mode,
} = self;
CompressedProof {
@@ -518,6 +523,7 @@ where
C_hat_w: C_hat_w.compress(),
},
),
hash_mode: *hash_mode,
}
}
@@ -535,6 +541,7 @@ where
pi,
pi_kzg,
compute_load_proof_fields,
hash_mode,
} = compressed;
Ok(Proof {
@@ -562,6 +569,7 @@ where
} else {
None
},
hash_mode,
})
}
}
@@ -812,6 +820,7 @@ pub fn prove<G: Curve>(
metadata,
load,
seed,
PkeV2HashMode::Compact,
ProofSanityCheckMode::Panic,
)
}
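
With this hunk, the public `prove` entry point hard-codes the new compact mode; only the internal `prove_impl` takes `hash_mode` as a parameter, which the tests later in this diff use to exercise the legacy modes. Roughly, with the signature details assumed rather than taken from the diff:

    // Sketch, assuming the argument types that the test code below passes in.
    pub fn prove<G: Curve>(
        public: (&PublicParams<G>, &PublicCommit<G>),
        private_commit: &PrivateCommit<G>,
        metadata: &[u8],
        load: ComputeLoad,
        seed: &[u8],
    ) -> Proof<G> {
        prove_impl(
            public,
            private_commit,
            metadata,
            load,
            seed,
            PkeV2HashMode::Compact,      // new proofs always use the compact hash
            ProofSanityCheckMode::Panic,
        )
    }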
@@ -822,6 +831,7 @@ fn prove_impl<G: Curve>(
metadata: &[u8],
load: ComputeLoad,
seed: &[u8],
hash_mode: PkeV2HashMode,
sanity_check_mode: ProofSanityCheckMode,
) -> Proof<G> {
_ = load;
@@ -966,6 +976,7 @@ fn prove_impl<G: Curve>(
C_hat_e_bytes.as_ref(),
C_e_bytes.as_ref(),
C_r_tilde_bytes.as_ref(),
hash_mode,
);
let R = |i: usize, j: usize| R[i + j * 128];
@@ -1430,7 +1441,7 @@ fn prove_impl<G: Curve>(
ComputeLoad::Proof => vec![G::Zp::ZERO; 1 + n],
ComputeLoad::Verify => vec![],
};
let mut P_w = match load {
let mut P_omega = match load {
ComputeLoad::Proof => vec![G::Zp::ZERO; 1 + d + k + 4],
ComputeLoad::Verify => vec![],
};
@@ -1502,15 +1513,15 @@ fn prove_impl<G: Curve>(
}
}
if !P_w.is_empty() {
P_w[1..].copy_from_slice(&omega[..d + k + 4]);
if !P_omega.is_empty() {
P_omega[1..].copy_from_slice(&omega[..d + k + 4]);
}
let mut p_h1 = G::Zp::ZERO;
let mut p_h2 = G::Zp::ZERO;
let mut p_t = G::Zp::ZERO;
let mut p_h3 = G::Zp::ZERO;
let mut p_w = G::Zp::ZERO;
let mut p_omega = G::Zp::ZERO;
let mut pow = G::Zp::ONE;
for j in 0..n + 1 {
@@ -1521,14 +1532,21 @@ fn prove_impl<G: Curve>(
if j < P_h3.len() {
p_h3 += P_h3[j] * pow;
}
if j < P_w.len() {
p_w += P_w[j] * pow;
if j < P_omega.len() {
p_omega += P_omega[j] * pow;
}
pow = pow * z;
}
let chi = z_hash.gen_chi(p_h1, p_h2, p_t);
let p_h3_opt = if P_h3.is_empty() { None } else { Some(p_h3) };
let p_omega_opt = if P_omega.is_empty() {
None
} else {
Some(p_omega)
};
let chi = z_hash.gen_chi(p_h1, p_h2, p_t, p_h3_opt, p_omega_opt);
let mut Q_kzg = vec![G::Zp::ZERO; 1 + n];
let chi2 = chi * chi;
@@ -1539,11 +1557,11 @@ fn prove_impl<G: Curve>(
if j < P_h3.len() {
Q_kzg[j] += chi3 * P_h3[j];
}
if j < P_w.len() {
Q_kzg[j] += chi4 * P_w[j];
if j < P_omega.len() {
Q_kzg[j] += chi4 * P_omega[j];
}
}
Q_kzg[0] -= p_h1 + chi * p_h2 + chi2 * p_t + chi3 * p_h3 + chi4 * p_w;
Q_kzg[0] -= p_h1 + chi * p_h2 + chi2 * p_t + chi3 * p_h3 + chi4 * p_omega;
// https://en.wikipedia.org/wiki/Polynomial_long_division#Pseudocode
let mut q = vec![G::Zp::ZERO; n];
@@ -1568,6 +1586,7 @@ fn prove_impl<G: Curve>(
pi,
pi_kzg,
compute_load_proof_fields,
hash_mode,
}
}
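
For reference, the accumulation just above folds the renamed `P_omega` polynomial into the batched KZG opening at the fourth power of the challenge chi. Reading the changed lines together with their unchanged surroundings (the chi^0 to chi^2 contributions and the exact divisor are not shown here, so this is an interpretation consistent with standard KZG batching rather than text taken from the commit):

    \[
    Q_{\mathrm{kzg}}(X) = P_{h_1}(X) + \chi P_{h_2}(X) + \chi^2 P_t(X) + \chi^3 P_{h_3}(X) + \chi^4 P_\omega(X)
        - \bigl(p_{h_1} + \chi p_{h_2} + \chi^2 p_t + \chi^3 p_{h_3} + \chi^4 p_\omega\bigr),
    \qquad
    q(X) = \frac{Q_{\mathrm{kzg}}(X)}{X - z}
    \]

where each lowercase p is the evaluation of the corresponding polynomial at z (built in the `pow = z^j` loop above), so Q_kzg vanishes at z, the long division is exact, and pi_kzg presumably commits to the quotient q.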
@@ -1703,6 +1722,7 @@ pub fn verify_impl<G: Curve>(
pi,
pi_kzg,
ref compute_load_proof_fields,
hash_mode,
} = proof;
let pairing = G::Gt::pairing;
@@ -1766,6 +1786,7 @@ pub fn verify_impl<G: Curve>(
C_hat_e_bytes.as_ref(),
C_e_bytes.as_ref(),
C_r_tilde_bytes.as_ref(),
hash_mode,
);
let R = |i: usize, j: usize| R[i + j * 128];
@@ -1896,7 +1917,7 @@ pub fn verify_impl<G: Curve>(
ComputeLoad::Proof => vec![G::Zp::ZERO; 1 + n],
ComputeLoad::Verify => vec![],
};
let mut P_w = match load {
let mut P_omega = match load {
ComputeLoad::Proof => vec![G::Zp::ZERO; 1 + d + k + 4],
ComputeLoad::Verify => vec![],
};
@@ -1968,15 +1989,15 @@ pub fn verify_impl<G: Curve>(
}
}
if !P_w.is_empty() {
P_w[1..].copy_from_slice(&omega[..d + k + 4]);
if !P_omega.is_empty() {
P_omega[1..].copy_from_slice(&omega[..d + k + 4]);
}
let mut p_h1 = G::Zp::ZERO;
let mut p_h2 = G::Zp::ZERO;
let mut p_t = G::Zp::ZERO;
let mut p_h3 = G::Zp::ZERO;
let mut p_w = G::Zp::ZERO;
let mut p_omega = G::Zp::ZERO;
let mut pow = G::Zp::ONE;
for j in 0..n + 1 {
@@ -1987,14 +2008,21 @@ pub fn verify_impl<G: Curve>(
if j < P_h3.len() {
p_h3 += P_h3[j] * pow;
}
if j < P_w.len() {
p_w += P_w[j] * pow;
if j < P_omega.len() {
p_omega += P_omega[j] * pow;
}
pow = pow * z;
}
let chi = z_hash.gen_chi(p_h1, p_h2, p_t);
let p_h3_opt = if P_h3.is_empty() { None } else { Some(p_h3) };
let p_omega_opt = if P_omega.is_empty() {
None
} else {
Some(p_omega)
};
let chi = z_hash.gen_chi(p_h1, p_h2, p_t, p_h3_opt, p_omega_opt);
let chi2 = chi * chi;
let chi3 = chi2 * chi;
@@ -2012,7 +2040,7 @@ pub fn verify_impl<G: Curve>(
C_hat += C_hat_w.mul_scalar(chi4);
}
C_hat
} - g_hat.mul_scalar(p_t * chi2 + p_h3 * chi3 + p_w * chi4),
} - g_hat.mul_scalar(p_t * chi2 + p_h3 * chi3 + p_omega * chi4),
);
let rhs = pairing(
pi_kzg,
@@ -2200,12 +2228,122 @@ mod tests {
}
}
#[test]
fn test_pke_legacy_hash() {
let PkeTestParameters {
d,
k,
B,
q,
t,
msbs_zero_padding_bit_count,
} = PKEV2_TEST_PARAMS;
let effective_cleartext_t = t >> msbs_zero_padding_bit_count;
let seed = thread_rng().gen();
println!("pkev2 legacy hash seed: {seed:x}");
let rng = &mut StdRng::seed_from_u64(seed);
let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS);
let ct = testcase.encrypt(PKEV2_TEST_PARAMS);
let fake_e1 = (0..d)
.map(|_| (rng.gen::<u64>() % (2 * B)) as i64 - B as i64)
.collect::<Vec<_>>();
let fake_e2 = (0..k)
.map(|_| (rng.gen::<u64>() % (2 * B)) as i64 - B as i64)
.collect::<Vec<_>>();
let fake_r = (0..d)
.map(|_| (rng.gen::<u64>() % 2) as i64)
.collect::<Vec<_>>();
let fake_m = (0..k)
.map(|_| (rng.gen::<u64>() % effective_cleartext_t) as i64)
.collect::<Vec<_>>();
let mut fake_metadata = [255u8; METADATA_LEN];
fake_metadata.fill_with(|| rng.gen::<u8>());
// To check management of bigger k_max from CRS during test
let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
let public_param = crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng);
for (use_fake_r, use_fake_e1, use_fake_e2, use_fake_m, use_fake_metadata_verify) in itertools::iproduct!(
[false, true],
[false, true],
[false, true],
[false, true],
[false, true]
) {
let (public_commit, private_commit) = commit(
testcase.a.clone(),
testcase.b.clone(),
ct.c1.clone(),
ct.c2.clone(),
if use_fake_r {
fake_r.clone()
} else {
testcase.r.clone()
},
if use_fake_e1 {
fake_e1.clone()
} else {
testcase.e1.clone()
},
if use_fake_m {
fake_m.clone()
} else {
testcase.m.clone()
},
if use_fake_e2 {
fake_e2.clone()
} else {
testcase.e2.clone()
},
&public_param,
);
for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
for hash_mode in [PkeV2HashMode::BackwardCompat, PkeV2HashMode::Classical] {
let proof = prove_impl(
(&public_param, &public_commit),
&private_commit,
&testcase.metadata,
load,
&seed.to_le_bytes(),
hash_mode,
ProofSanityCheckMode::Panic,
);
let verify_metadata = if use_fake_metadata_verify {
&fake_metadata
} else {
&testcase.metadata
};
assert_eq!(
verify(&proof, (&public_param, &public_commit), verify_metadata).is_err(),
use_fake_e1
|| use_fake_e2
|| use_fake_r
|| use_fake_m
|| use_fake_metadata_verify
);
}
}
}
}
fn prove_and_verify(
testcase: &PkeTestcase,
ct: &PkeTestCiphertext,
crs: &PublicParams<Curve>,
load: ComputeLoad,
seed: &[u8],
hash_mode: PkeV2HashMode,
sanity_check_mode: ProofSanityCheckMode,
) -> VerificationResult {
let (public_commit, private_commit) = commit(
@@ -2226,6 +2364,7 @@ mod tests {
&testcase.metadata,
load,
seed,
hash_mode,
sanity_check_mode,
);
@@ -2236,20 +2375,22 @@ mod tests {
}
}
#[allow(clippy::too_many_arguments)]
fn assert_prove_and_verify(
testcase: &PkeTestcase,
ct: &PkeTestCiphertext,
testcase_name: &str,
crs: &PublicParams<Curve>,
seed: &[u8],
hash_mode: PkeV2HashMode,
sanity_check_mode: ProofSanityCheckMode,
expected_result: VerificationResult,
) {
for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
assert_eq!(
prove_and_verify(testcase, ct, crs, load, seed, sanity_check_mode),
prove_and_verify(testcase, ct, crs, load, seed, hash_mode, sanity_check_mode),
expected_result,
"Testcase {testcase_name} with load {load} failed"
"Testcase {testcase_name} {hash_mode:?} hash with load {load} failed"
)
}
}
@@ -2515,6 +2656,7 @@ mod tests {
&format!("{name}_crs"),
&crs,
&seed.to_le_bytes(),
PkeV2HashMode::Compact,
ProofSanityCheckMode::Ignore,
expected_result,
);
@@ -2524,6 +2666,7 @@ mod tests {
&format!("{name}_crs_max_k"),
&crs_max_k,
&seed.to_le_bytes(),
PkeV2HashMode::Compact,
ProofSanityCheckMode::Ignore,
expected_result,
);
@@ -2617,6 +2760,7 @@ mod tests {
test_name,
&public_param,
&seed.to_le_bytes(),
PkeV2HashMode::Compact,
ProofSanityCheckMode::Panic,
VerificationResult::Reject,
);
@@ -2789,6 +2933,7 @@ mod tests {
"testcase_bad_delta",
&crs,
&seed.to_le_bytes(),
PkeV2HashMode::Compact,
ProofSanityCheckMode::Panic,
VerificationResult::Reject,
);
@@ -2830,6 +2975,7 @@ mod tests {
&format!("testcase_big_params_{bound:?}"),
&crs,
&seed.to_le_bytes(),
PkeV2HashMode::Compact,
ProofSanityCheckMode::Panic,
VerificationResult::Accept,
);