diff --git a/tfhe-zk-pok/src/backward_compatibility/mod.rs b/tfhe-zk-pok/src/backward_compatibility/mod.rs index 91753ea16..c54843549 100644 --- a/tfhe-zk-pok/src/backward_compatibility/mod.rs +++ b/tfhe-zk-pok/src/backward_compatibility/mod.rs @@ -1,8 +1,12 @@ +pub mod pke; +pub mod pke_v2; + +use std::error::Error; +use std::fmt::Display; + use tfhe_versionable::VersionsDispatch; -use crate::curve_api::{Compressible, Curve}; -use crate::proofs::pke::{CompressedProof as PKEv1CompressedProof, Proof as PKEv1Proof}; -use crate::proofs::pke_v2::{CompressedProof as PKEv2CompressedProof, Proof as PKEv2Proof}; +use crate::curve_api::Curve; use crate::proofs::GroupElements; use crate::serialization::{ SerializableAffine, SerializableCubicExtField, SerializableFp, SerializableFp2, @@ -34,33 +38,20 @@ pub type SerializableG1AffineVersions = SerializableAffineVersions; pub type SerializableFp12Versions = SerializableQuadExtFieldVersions; -#[derive(VersionsDispatch)] -pub enum PKEv1ProofVersions { - V0(PKEv1Proof), +/// The proof was missing some elements +#[derive(Debug)] +pub struct IncompleteProof; + +impl Display for IncompleteProof { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "incomplete serialized ZK proof, missing some pre-computed elements" + ) + } } -#[derive(VersionsDispatch)] -pub enum PKEv2ProofVersions { - V0(PKEv2Proof), -} - -#[derive(VersionsDispatch)] -pub enum PKEv1CompressedProofVersions -where - G::G1: Compressible, - G::G2: Compressible, -{ - V0(PKEv1CompressedProof), -} - -#[derive(VersionsDispatch)] -pub enum PKEv2CompressedProofVersions -where - G::G1: Compressible, - G::G2: Compressible, -{ - V0(PKEv2CompressedProof), -} +impl Error for IncompleteProof {} #[derive(VersionsDispatch)] #[allow(dead_code)] diff --git a/tfhe-zk-pok/src/backward_compatibility/pke.rs b/tfhe-zk-pok/src/backward_compatibility/pke.rs new file mode 100644 index 000000000..c84a02647 --- /dev/null +++ b/tfhe-zk-pok/src/backward_compatibility/pke.rs @@ -0,0 +1,117 @@ +use tfhe_versionable::{Upgrade, Version, VersionsDispatch}; + +use crate::curve_api::{CompressedG1, CompressedG2, Compressible, Curve}; +use crate::proofs::pke::{ + CompressedComputeLoadProofFields, CompressedProof, ComputeLoadProofFields, Proof, +}; + +use super::IncompleteProof; + +#[derive(Version)] +pub struct ProofV0 { + c_hat: G::G2, + c_y: G::G1, + pi: G::G1, + c_hat_t: Option, + c_h: Option, + pi_kzg: Option, +} + +impl Upgrade> for ProofV0 { + type Error = IncompleteProof; + + fn upgrade(self) -> Result, Self::Error> { + let compute_load_proof_fields = match (self.c_hat_t, self.c_h, self.pi_kzg) { + (None, None, None) => None, + (Some(c_hat_t), Some(c_h), Some(pi_kzg)) => Some(ComputeLoadProofFields { + c_hat_t, + c_h, + pi_kzg, + }), + _ => { + return Err(IncompleteProof); + } + }; + + Ok(Proof { + c_hat: self.c_hat, + c_y: self.c_y, + pi: self.pi, + compute_load_proof_fields, + }) + } +} + +#[derive(VersionsDispatch)] +pub enum ProofVersions { + V0(ProofV0), + V1(Proof), +} + +#[derive(VersionsDispatch)] +#[allow(dead_code)] +pub(crate) enum ComputeLoadProofFieldVersions { + V0(ComputeLoadProofFields), +} + +pub struct CompressedProofV0 +where + G::G1: Compressible, + G::G2: Compressible, +{ + c_hat: CompressedG2, + c_y: CompressedG1, + pi: CompressedG1, + c_hat_t: Option>, + c_h: Option>, + pi_kzg: Option>, +} + +impl Upgrade> for CompressedProofV0 +where + G::G1: Compressible, + G::G2: Compressible, +{ + type Error = IncompleteProof; + + fn upgrade(self) -> Result, Self::Error> { + let compute_load_proof_fields = match (self.c_hat_t, self.c_h, self.pi_kzg) { + (None, None, None) => None, + (Some(c_hat_t), Some(c_h), Some(pi_kzg)) => Some(CompressedComputeLoadProofFields { + c_hat_t, + c_h, + pi_kzg, + }), + _ => { + return Err(IncompleteProof); + } + }; + + Ok(CompressedProof { + c_hat: self.c_hat, + c_y: self.c_y, + pi: self.pi, + compute_load_proof_fields, + }) + } +} + +#[derive(VersionsDispatch)] +pub enum CompressedProofVersions +where + G::G1: Compressible, + G::G2: Compressible, +{ + V0(CompressedProofV0), + V1(CompressedProof), +} + +#[derive(VersionsDispatch)] +#[allow(dead_code)] +pub(crate) enum CompressedComputeLoadProofFieldsVersions +where + G::G1: Compressible, + G::G2: Compressible, +{ + V0(CompressedComputeLoadProofFields), +} diff --git a/tfhe-zk-pok/src/backward_compatibility/pke_v2.rs b/tfhe-zk-pok/src/backward_compatibility/pke_v2.rs new file mode 100644 index 000000000..43af70559 --- /dev/null +++ b/tfhe-zk-pok/src/backward_compatibility/pke_v2.rs @@ -0,0 +1,142 @@ +// to follow the notation of the paper +#![allow(non_snake_case)] + +use tfhe_versionable::{Upgrade, Version, VersionsDispatch}; + +use crate::curve_api::{CompressedG1, CompressedG2, Compressible, Curve}; +use crate::proofs::pke_v2::{ + CompressedComputeLoadProofFields, CompressedProof, ComputeLoadProofFields, Proof, +}; + +use super::IncompleteProof; + +#[derive(Version)] +pub struct ProofV0 { + C_hat_e: G::G2, + C_e: G::G1, + C_r_tilde: G::G1, + C_R: G::G1, + C_hat_bin: G::G2, + C_y: G::G1, + C_h1: G::G1, + C_h2: G::G1, + C_hat_t: G::G2, + pi: G::G1, + pi_kzg: G::G1, + + C_hat_h3: Option, + C_hat_w: Option, +} + +impl Upgrade> for ProofV0 { + type Error = IncompleteProof; + + fn upgrade(self) -> Result, Self::Error> { + let compute_load_proof_fields = match (self.C_hat_h3, self.C_hat_w) { + (None, None) => None, + (Some(C_hat_h3), Some(C_hat_w)) => Some(ComputeLoadProofFields { C_hat_h3, C_hat_w }), + _ => return Err(IncompleteProof), + }; + + Ok(Proof { + C_hat_e: self.C_hat_e, + C_e: self.C_e, + C_r_tilde: self.C_r_tilde, + C_R: self.C_R, + C_hat_bin: self.C_hat_bin, + C_y: self.C_y, + C_h1: self.C_h1, + C_h2: self.C_h2, + C_hat_t: self.C_hat_t, + pi: self.pi, + pi_kzg: self.pi_kzg, + compute_load_proof_fields, + }) + } +} + +#[derive(VersionsDispatch)] +pub enum ProofVersions { + V0(ProofV0), + V1(Proof), +} + +#[derive(VersionsDispatch)] +#[allow(dead_code)] +pub(crate) enum ComputeLoadProofFieldVersions { + V0(ComputeLoadProofFields), +} + +pub struct CompressedProofV0 +where + G::G1: Compressible, + G::G2: Compressible, +{ + C_hat_e: CompressedG2, + C_e: CompressedG1, + C_r_tilde: CompressedG1, + C_R: CompressedG1, + C_hat_bin: CompressedG2, + C_y: CompressedG1, + C_h1: CompressedG1, + C_h2: CompressedG1, + C_hat_t: CompressedG2, + pi: CompressedG1, + pi_kzg: CompressedG1, + + C_hat_h3: Option>, + C_hat_w: Option>, +} + +impl Upgrade> for CompressedProofV0 +where + G::G1: Compressible, + G::G2: Compressible, +{ + type Error = IncompleteProof; + + fn upgrade(self) -> Result, Self::Error> { + let compute_load_proof_fields = match (self.C_hat_h3, self.C_hat_w) { + (None, None) => None, + (Some(C_hat_h3), Some(C_hat_w)) => { + Some(CompressedComputeLoadProofFields { C_hat_h3, C_hat_w }) + } + _ => return Err(IncompleteProof), + }; + + Ok(CompressedProof { + C_hat_e: self.C_hat_e, + C_e: self.C_e, + C_r_tilde: self.C_r_tilde, + C_R: self.C_R, + C_hat_bin: self.C_hat_bin, + C_y: self.C_y, + C_h1: self.C_h1, + C_h2: self.C_h2, + C_hat_t: self.C_hat_t, + pi: self.pi, + pi_kzg: self.pi_kzg, + compute_load_proof_fields, + }) + } +} + +#[derive(VersionsDispatch)] +pub enum CompressedProofVersions +where + G::G1: Compressible, + G::G2: Compressible, +{ + V0(CompressedProofV0), + V1(CompressedProof), +} + +#[derive(VersionsDispatch)] +#[allow(dead_code)] +pub(crate) enum CompressedComputeLoadProofFieldsVersions +where + G::G1: Compressible, + G::G2: Compressible, +{ + V0(CompressedComputeLoadProofFields), +} diff --git a/tfhe-zk-pok/src/curve_api.rs b/tfhe-zk-pok/src/curve_api.rs index 8bf60f37e..af7ef26cc 100644 --- a/tfhe-zk-pok/src/curve_api.rs +++ b/tfhe-zk-pok/src/curve_api.rs @@ -128,6 +128,9 @@ pub trait Compressible: Sized { fn uncompress(compressed: Self::Compressed) -> Result; } +pub type CompressedG1 = <::G1 as Compressible>::Compressed; +pub type CompressedG2 = <::G2 as Compressible>::Compressed; + pub trait PairingGroupOps: Copy + Send diff --git a/tfhe-zk-pok/src/proofs/pke.rs b/tfhe-zk-pok/src/proofs/pke.rs index 1c078c3cc..0bd53ad41 100644 --- a/tfhe-zk-pok/src/proofs/pke.rs +++ b/tfhe-zk-pok/src/proofs/pke.rs @@ -1,6 +1,9 @@ // TODO: refactor copy-pasted code in proof/verify -use crate::backward_compatibility::{PKEv1CompressedProofVersions, PKEv1ProofVersions}; +use crate::backward_compatibility::pke::{ + CompressedComputeLoadProofFieldsVersions, CompressedProofVersions, + ComputeLoadProofFieldVersions, ProofVersions, +}; use crate::serialization::{ try_vec_to_array, InvalidSerializedAffineError, InvalidSerializedPublicParamsError, SerializableGroupElements, SerializablePKEv1PublicParams, @@ -186,14 +189,26 @@ impl PublicParams { deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>", serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize" ))] -#[versionize(PKEv1ProofVersions)] +#[versionize(ProofVersions)] pub struct Proof { - c_hat: G::G2, - c_y: G::G1, - pi: G::G1, - c_hat_t: Option, - c_h: Option, - pi_kzg: Option, + pub(crate) c_hat: G::G2, + pub(crate) c_y: G::G1, + pub(crate) pi: G::G1, + pub(crate) compute_load_proof_fields: Option>, +} + +/// These fields can be pre-computed on the prover side in the faster Verifier scheme. If that's the +/// case, they should be included in the proof. +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Versionize)] +#[serde(bound( + deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>", + serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize" +))] +#[versionize(ComputeLoadProofFieldVersions)] +pub(crate) struct ComputeLoadProofFields { + pub(crate) c_hat_t: G::G2, + pub(crate) c_h: G::G1, + pub(crate) pi_kzg: G::G1, } type CompressedG2 = <::G2 as Compressible>::Compressed; @@ -204,18 +219,32 @@ type CompressedG1 = <::G1 as Compressible>::Compressed; deserialize = "G: Curve, CompressedG1: serde::Deserialize<'de>, CompressedG2: serde::Deserialize<'de>", serialize = "G: Curve, CompressedG1: serde::Serialize, CompressedG2: serde::Serialize" ))] -#[versionize(PKEv1CompressedProofVersions)] +#[versionize(CompressedProofVersions)] pub struct CompressedProof where G::G1: Compressible, G::G2: Compressible, { - c_hat: CompressedG2, - c_y: CompressedG1, - pi: CompressedG1, - c_hat_t: Option>, - c_h: Option>, - pi_kzg: Option>, + pub(crate) c_hat: CompressedG2, + pub(crate) c_y: CompressedG1, + pub(crate) pi: CompressedG1, + pub(crate) compute_load_proof_fields: Option>, +} + +#[derive(Serialize, Deserialize, Versionize)] +#[serde(bound( + deserialize = "G: Curve, CompressedG1: serde::Deserialize<'de>, CompressedG2: serde::Deserialize<'de>", + serialize = "G: Curve, CompressedG1: serde::Serialize, CompressedG2: serde::Serialize" +))] +#[versionize(CompressedComputeLoadProofFieldsVersions)] +pub(crate) struct CompressedComputeLoadProofFields +where + G::G1: Compressible, + G::G2: Compressible, +{ + pub(crate) c_hat_t: CompressedG2, + pub(crate) c_h: CompressedG1, + pub(crate) pi_kzg: CompressedG1, } impl Compressible for Proof @@ -232,18 +261,24 @@ where c_hat, c_y, pi, - c_hat_t, - c_h, - pi_kzg, + compute_load_proof_fields, } = self; CompressedProof { c_hat: c_hat.compress(), c_y: c_y.compress(), pi: pi.compress(), - c_hat_t: c_hat_t.map(|val| val.compress()), - c_h: c_h.map(|val| val.compress()), - pi_kzg: pi_kzg.map(|val| val.compress()), + compute_load_proof_fields: compute_load_proof_fields.as_ref().map( + |ComputeLoadProofFields { + c_hat_t, + c_h, + pi_kzg, + }| CompressedComputeLoadProofFields { + c_hat_t: c_hat_t.compress(), + c_h: c_h.compress(), + pi_kzg: pi_kzg.compress(), + }, + ), } } @@ -252,28 +287,29 @@ where c_hat, c_y, pi, - c_hat_t, - c_h, - pi_kzg, + compute_load_proof_fields, } = compressed; Ok(Proof { c_hat: G::G2::uncompress(c_hat)?, c_y: G::G1::uncompress(c_y)?, pi: G::G1::uncompress(pi)?, - c_hat_t: c_hat_t.map(G::G2::uncompress).transpose()?, - c_h: c_h.map(G::G1::uncompress).transpose()?, - pi_kzg: pi_kzg.map(G::G1::uncompress).transpose()?, - }) - } -} -impl Proof { - pub fn content_is_usable(&self) -> bool { - matches!( - (self.c_hat_t, self.c_h, self.pi_kzg), - (None, None, None) | (Some(_), Some(_), Some(_)) - ) + compute_load_proof_fields: if let Some(CompressedComputeLoadProofFields { + c_hat_t, + c_h, + pi_kzg, + }) = compute_load_proof_fields + { + Some(ComputeLoadProofFields { + c_hat_t: G::G2::uncompress(c_hat_t)?, + c_h: G::G1::uncompress(c_h)?, + pi_kzg: G::G1::uncompress(pi_kzg)?, + }) + } else { + None + }, + }) } } @@ -793,18 +829,18 @@ pub fn prove( c_hat, c_y, pi, - c_hat_t: Some(c_hat_t), - c_h: Some(c_h), - pi_kzg: Some(pi_kzg), + compute_load_proof_fields: Some(ComputeLoadProofFields { + c_hat_t, + c_h, + pi_kzg, + }), } } else { Proof { c_hat, c_y, pi, - c_hat_t: None, - c_h: None, - pi_kzg: None, + compute_load_proof_fields: None, } } } @@ -939,10 +975,9 @@ pub fn verify( c_hat, c_y, pi, - c_hat_t, - c_h, - pi_kzg, + ref compute_load_proof_fields, } = proof; + let e = G::Gt::pairing; let &PublicParams { @@ -1081,7 +1116,12 @@ pub fn verify( let [delta_eq, delta_y] = delta; let delta = [delta_eq, delta_y, delta_theta]; - if let (Some(pi_kzg), Some(c_hat_t), Some(c_h)) = (pi_kzg, c_hat_t, c_h) { + if let Some(&ComputeLoadProofFields { + c_hat_t, + c_h, + pi_kzg, + }) = compute_load_proof_fields.as_ref() + { let mut z = G::Zp::ZERO; G::Zp::hash( core::array::from_mut(&mut z), diff --git a/tfhe-zk-pok/src/proofs/pke_v2.rs b/tfhe-zk-pok/src/proofs/pke_v2.rs index e79bdcafe..77a56975a 100644 --- a/tfhe-zk-pok/src/proofs/pke_v2.rs +++ b/tfhe-zk-pok/src/proofs/pke_v2.rs @@ -2,7 +2,8 @@ #![allow(non_snake_case)] use super::*; -use crate::backward_compatibility::{PKEv2CompressedProofVersions, PKEv2ProofVersions}; +use crate::backward_compatibility::pke_v2::{CompressedProofVersions, ProofVersions}; +use crate::curve_api::{CompressedG1, CompressedG2}; use crate::four_squares::*; use crate::serialization::{ try_vec_to_array, InvalidSerializedAffineError, InvalidSerializedPublicParamsError, @@ -229,52 +230,69 @@ impl PublicParams { deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>", serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize" ))] -#[versionize(PKEv2ProofVersions)] +#[versionize(ProofVersions)] pub struct Proof { - C_hat_e: G::G2, - C_e: G::G1, - C_r_tilde: G::G1, - C_R: G::G1, - C_hat_bin: G::G2, - C_y: G::G1, - C_h1: G::G1, - C_h2: G::G1, - C_hat_t: G::G2, - pi: G::G1, - pi_kzg: G::G1, + pub(crate) C_hat_e: G::G2, + pub(crate) C_e: G::G1, + pub(crate) C_r_tilde: G::G1, + pub(crate) C_R: G::G1, + pub(crate) C_hat_bin: G::G2, + pub(crate) C_y: G::G1, + pub(crate) C_h1: G::G1, + pub(crate) C_h2: G::G1, + pub(crate) C_hat_t: G::G2, + pub(crate) pi: G::G1, + pub(crate) pi_kzg: G::G1, - C_hat_h3: Option, - C_hat_w: Option, + pub(crate) compute_load_proof_fields: Option>, } -type CompressedG2 = <::G2 as Compressible>::Compressed; -type CompressedG1 = <::G1 as Compressible>::Compressed; +/// These fields can be pre-computed on the prover side in the faster Verifier scheme. If that's the +/// case, they should be included in the proof. +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub(crate) struct ComputeLoadProofFields { + pub(crate) C_hat_h3: G::G2, + pub(crate) C_hat_w: G::G2, +} #[derive(Serialize, Deserialize, Versionize)] #[serde(bound( deserialize = "G: Curve, CompressedG1: serde::Deserialize<'de>, CompressedG2: serde::Deserialize<'de>", serialize = "G: Curve, CompressedG1: serde::Serialize, CompressedG2: serde::Serialize" ))] -#[versionize(PKEv2CompressedProofVersions)] +#[versionize(CompressedProofVersions)] pub struct CompressedProof where G::G1: Compressible, G::G2: Compressible, { - C_hat_e: CompressedG2, - C_e: CompressedG1, - C_r_tilde: CompressedG1, - C_R: CompressedG1, - C_hat_bin: CompressedG2, - C_y: CompressedG1, - C_h1: CompressedG1, - C_h2: CompressedG1, - C_hat_t: CompressedG2, - pi: CompressedG1, - pi_kzg: CompressedG1, + pub(crate) C_hat_e: CompressedG2, + pub(crate) C_e: CompressedG1, + pub(crate) C_r_tilde: CompressedG1, + pub(crate) C_R: CompressedG1, + pub(crate) C_hat_bin: CompressedG2, + pub(crate) C_y: CompressedG1, + pub(crate) C_h1: CompressedG1, + pub(crate) C_h2: CompressedG1, + pub(crate) C_hat_t: CompressedG2, + pub(crate) pi: CompressedG1, + pub(crate) pi_kzg: CompressedG1, - C_hat_h3: Option>, - C_hat_w: Option>, + pub(crate) compute_load_proof_fields: Option>, +} + +#[derive(Serialize, Deserialize)] +#[serde(bound( + deserialize = "G: Curve, CompressedG1: serde::Deserialize<'de>, CompressedG2: serde::Deserialize<'de>", + serialize = "G: Curve, CompressedG1: serde::Serialize, CompressedG2: serde::Serialize" +))] +pub(crate) struct CompressedComputeLoadProofFields +where + G::G1: Compressible, + G::G2: Compressible, +{ + pub(crate) C_hat_h3: CompressedG2, + pub(crate) C_hat_w: CompressedG2, } impl Compressible for Proof @@ -299,8 +317,7 @@ where C_hat_t, pi, pi_kzg, - C_hat_h3, - C_hat_w, + compute_load_proof_fields, } = self; CompressedProof { @@ -315,8 +332,13 @@ where C_hat_t: C_hat_t.compress(), pi: pi.compress(), pi_kzg: pi_kzg.compress(), - C_hat_h3: C_hat_h3.map(|val| val.compress()), - C_hat_w: C_hat_w.map(|val| val.compress()), + + compute_load_proof_fields: compute_load_proof_fields.as_ref().map( + |ComputeLoadProofFields { C_hat_h3, C_hat_w }| CompressedComputeLoadProofFields { + C_hat_h3: C_hat_h3.compress(), + C_hat_w: C_hat_w.compress(), + }, + ), } } @@ -333,8 +355,7 @@ where C_hat_t, pi, pi_kzg, - C_hat_h3, - C_hat_w, + compute_load_proof_fields, } = compressed; Ok(Proof { @@ -349,8 +370,19 @@ where C_hat_t: G::G2::uncompress(C_hat_t)?, pi: G::G1::uncompress(pi)?, pi_kzg: G::G1::uncompress(pi_kzg)?, - C_hat_h3: C_hat_h3.map(G::G2::uncompress).transpose()?, - C_hat_w: C_hat_w.map(G::G2::uncompress).transpose()?, + + compute_load_proof_fields: if let Some(CompressedComputeLoadProofFields { + C_hat_h3, + C_hat_w, + }) = compute_load_proof_fields + { + Some(ComputeLoadProofFields { + C_hat_h3: G::G2::uncompress(C_hat_h3)?, + C_hat_w: G::G2::uncompress(C_hat_w)?, + }) + } else { + None + }, }) } } @@ -1383,42 +1415,50 @@ pub fn prove( .collect::>(); scalars.reverse(); let C_h2 = G::G1::multi_mul_scalar(&g_list[..n], &scalars); - let (C_hat_h3, C_hat_w) = match load { - ComputeLoad::Proof => rayon::join( - || { - Some(G::G2::multi_mul_scalar( - &g_hat_list[n - (d + k)..n], - &(0..d + k) - .rev() - .map(|j| { - let mut acc = G::Zp::ZERO; - for (i, &phi) in phi.iter().enumerate() { - match R(i, d + k + 4 + j) { - 0 => {} - 1 => acc += phi, - -1 => acc -= phi, - _ => unreachable!(), + let compute_load_proof_fields = match load { + ComputeLoad::Proof => { + let (C_hat_h3, C_hat_w) = rayon::join( + || { + G::G2::multi_mul_scalar( + &g_hat_list[n - (d + k)..n], + &(0..d + k) + .rev() + .map(|j| { + let mut acc = G::Zp::ZERO; + for (i, &phi) in phi.iter().enumerate() { + match R(i, d + k + 4 + j) { + 0 => {} + 1 => acc += phi, + -1 => acc -= phi, + _ => unreachable!(), + } } - } - delta_r * acc - delta_theta_q * theta[j] - }) - .collect::>(), - )) - }, - || { - Some(G::G2::multi_mul_scalar( - &g_hat_list[..d + k + 4], - &w[..d + k + 4], - )) - }, - ), - ComputeLoad::Verify => (None, None), + delta_r * acc - delta_theta_q * theta[j] + }) + .collect::>(), + ) + }, + || G::G2::multi_mul_scalar(&g_hat_list[..d + k + 4], &w[..d + k + 4]), + ); + + Some(ComputeLoadProofFields { C_hat_h3, C_hat_w }) + } + ComputeLoad::Verify => None, }; - let C_hat_h3_bytes = C_hat_h3.map(G::G2::to_le_bytes); - let C_hat_w_bytes = C_hat_w.map(G::G2::to_le_bytes); - let C_hat_h3_bytes = C_hat_h3_bytes.as_ref().map(|x| x.as_ref()).unwrap_or(&[]); - let C_hat_w_bytes = C_hat_w_bytes.as_ref().map(|x| x.as_ref()).unwrap_or(&[]); + let byte_generators = + if let Some(ComputeLoadProofFields { C_hat_h3, C_hat_w }) = compute_load_proof_fields { + Some((G::G2::to_le_bytes(C_hat_h3), G::G2::to_le_bytes(C_hat_w))) + } else { + None + }; + + let (C_hat_h3_bytes, C_hat_w_bytes): (&[u8], &[u8]) = + if let Some((C_hat_h3_bytes_owner, C_hat_w_bytes_owner)) = byte_generators.as_ref() { + (C_hat_h3_bytes_owner.as_ref(), C_hat_w_bytes_owner.as_ref()) + } else { + (&[], &[]) + }; let C_hat_t = G::G2::multi_mul_scalar(g_hat_list, &t); @@ -1622,10 +1662,9 @@ pub fn prove( C_h1, C_h2, C_hat_t, - C_hat_h3, - C_hat_w, pi, pi_kzg, + compute_load_proof_fields, } } @@ -1747,11 +1786,11 @@ pub fn verify( C_h1, C_h2, C_hat_t, - C_hat_h3, - C_hat_w, pi, pi_kzg, + ref compute_load_proof_fields, } = proof; + let pairing = G::Gt::pairing; let &PublicParams { @@ -1803,10 +1842,20 @@ pub fn verify( return Err(()); } - let C_hat_h3_bytes = C_hat_h3.map(G::G2::to_le_bytes); - let C_hat_w_bytes = C_hat_w.map(G::G2::to_le_bytes); - let C_hat_h3_bytes = C_hat_h3_bytes.as_ref().map(|x| x.as_ref()).unwrap_or(&[]); - let C_hat_w_bytes = C_hat_w_bytes.as_ref().map(|x| x.as_ref()).unwrap_or(&[]); + let byte_generators = if let Some(&ComputeLoadProofFields { C_hat_h3, C_hat_w }) = + compute_load_proof_fields.as_ref() + { + Some((G::G2::to_le_bytes(C_hat_h3), G::G2::to_le_bytes(C_hat_w))) + } else { + None + }; + + let (C_hat_h3_bytes, C_hat_w_bytes): (&[u8], &[u8]) = + if let Some((C_hat_h3_bytes_owner, C_hat_w_bytes_owner)) = byte_generators.as_ref() { + (C_hat_h3_bytes_owner.as_ref(), C_hat_w_bytes_owner.as_ref()) + } else { + (&[], &[]) + }; let x_bytes = &*[ q.to_le_bytes().as_slice(), @@ -2059,8 +2108,11 @@ pub fn verify( let lhs2 = pairing( C_r_tilde, - match C_hat_h3 { - Some(C_hat_h3) => C_hat_h3, + match compute_load_proof_fields.as_ref() { + Some(&ComputeLoadProofFields { + C_hat_h3, + C_hat_w: _, + }) => C_hat_h3, None => G::G2::multi_mul_scalar( &g_hat_list[n - (d + k)..n], &(0..d + k) @@ -2093,8 +2145,11 @@ pub fn verify( ); let lhs4 = pairing( C_e.mul_scalar(delta_e), - match C_hat_w { - Some(C_hat_w) => C_hat_w, + match compute_load_proof_fields.as_ref() { + Some(&ComputeLoadProofFields { + C_hat_h3: _, + C_hat_w, + }) => C_hat_w, None => G::G2::multi_mul_scalar(&g_hat_list[..d + k + 4], &w[..d + k + 4]), }, ); @@ -2140,7 +2195,7 @@ pub fn verify( ], ); - let load = if C_hat_h3.is_some() && C_hat_w.is_some() { + let load = if compute_load_proof_fields.is_some() { ComputeLoad::Proof } else { ComputeLoad::Verify @@ -2293,10 +2348,8 @@ pub fn verify( g, { let mut C_hat = C_hat_t.mul_scalar(chi2); - if let Some(C_hat_h3) = C_hat_h3 { + if let Some(ComputeLoadProofFields { C_hat_h3, C_hat_w }) = compute_load_proof_fields { C_hat += C_hat_h3.mul_scalar(chi3); - } - if let Some(C_hat_w) = C_hat_w { C_hat += C_hat_w.mul_scalar(chi4); } C_hat @@ -2573,7 +2626,7 @@ mod tests { rng, ); - let compressed_proof = bincode::serialize(&proof.clone().compress()).unwrap(); + let compressed_proof = bincode::serialize(&proof.compress()).unwrap(); let proof = Proof::uncompress(bincode::deserialize(&compressed_proof).unwrap()).unwrap(); diff --git a/tfhe-zk-pok/src/proofs/rlwe.rs b/tfhe-zk-pok/src/proofs/rlwe.rs index 8ca9b9046..e74c6e84e 100644 --- a/tfhe-zk-pok/src/proofs/rlwe.rs +++ b/tfhe-zk-pok/src/proofs/rlwe.rs @@ -81,9 +81,14 @@ pub struct Proof { c_hat: G::G2, c_y: G::G1, pi: G::G1, - c_hat_t: Option, - c_h: Option, - pi_kzg: Option, + compute_load_proof_fields: Option>, +} + +#[derive(Clone, Debug)] +struct ComputeLoadProofFields { + c_hat_t: G::G2, + c_h: G::G1, + pi_kzg: G::G1, } pub fn crs_gen( @@ -594,18 +599,18 @@ pub fn prove( c_hat, c_y, pi, - c_hat_t: Some(c_hat_t), - c_h: Some(c_h), - pi_kzg: Some(pi_kzg), + compute_load_proof_fields: Some(ComputeLoadProofFields { + c_hat_t, + c_h, + pi_kzg, + }), } } else { Proof { c_hat, c_y, pi, - c_hat_t: None, - c_h: None, - pi_kzg: None, + compute_load_proof_fields: None, } } } @@ -619,10 +624,9 @@ pub fn verify( c_hat, c_y, pi, - c_hat_t, - c_h, - pi_kzg, + ref compute_load_proof_fields, } = proof; + let e = G::Gt::pairing; let &PublicParams { @@ -785,7 +789,12 @@ pub fn verify( } } - if let (Some(pi_kzg), Some(c_hat_t), Some(c_h)) = (pi_kzg, c_hat_t, c_h) { + if let Some(&ComputeLoadProofFields { + c_hat_t, + c_h, + pi_kzg, + }) = compute_load_proof_fields.as_ref() + { let mut z = G::Zp::ZERO; G::Zp::hash( core::array::from_mut(&mut z), diff --git a/tfhe/src/shortint/ciphertext/zk.rs b/tfhe/src/shortint/ciphertext/zk.rs index 569c46747..0fda8303c 100644 --- a/tfhe/src/shortint/ciphertext/zk.rs +++ b/tfhe/src/shortint/ciphertext/zk.rs @@ -228,15 +228,11 @@ impl ParameterSetConformant for ProvenCompactCiphertextList { let mut remaining_len = *total_expected_lwe_count; - for (compact_ct_list, proof) in proved_lists { + for (compact_ct_list, _proof) in proved_lists { if remaining_len == 0 { return false; } - if !proof.content_is_usable() { - return false; - } - let expected_len; if remaining_len > max_elements_per_compact_list {