diff --git a/tfhe-zk-pok/src/backward_compatibility/pke_v2.rs b/tfhe-zk-pok/src/backward_compatibility/pke_v2.rs index 70894811d..f9a14b861 100644 --- a/tfhe-zk-pok/src/backward_compatibility/pke_v2.rs +++ b/tfhe-zk-pok/src/backward_compatibility/pke_v2.rs @@ -1,11 +1,13 @@ // to follow the notation of the paper #![allow(non_snake_case)] +use std::convert::Infallible; + use tfhe_versionable::{Upgrade, Version, VersionsDispatch}; use crate::curve_api::{CompressedG1, CompressedG2, Compressible, Curve}; use crate::proofs::pke_v2::{ - CompressedComputeLoadProofFields, CompressedProof, ComputeLoadProofFields, Proof, + CompressedComputeLoadProofFields, CompressedProof, ComputeLoadProofFields, PkeV2HashMode, Proof, }; use super::IncompleteProof; @@ -28,37 +30,107 @@ pub struct ProofV0 { C_hat_w: Option, } -impl Upgrade> for ProofV0 { +impl Upgrade> for ProofV0 { type Error = IncompleteProof; - fn upgrade(self) -> Result, Self::Error> { - let compute_load_proof_fields = match (self.C_hat_h3, self.C_hat_w) { + fn upgrade(self) -> Result, Self::Error> { + let ProofV0 { + C_hat_e, + C_e, + C_r_tilde, + C_R, + C_hat_bin, + C_y, + C_h1, + C_h2, + C_hat_t, + pi, + pi_kzg, + C_hat_h3, + C_hat_w, + } = self; + + let compute_load_proof_fields = match (C_hat_h3, C_hat_w) { (None, None) => None, (Some(C_hat_h3), Some(C_hat_w)) => Some(ComputeLoadProofFields { C_hat_h3, C_hat_w }), _ => return Err(IncompleteProof), }; - Ok(Proof { - C_hat_e: self.C_hat_e, - C_e: self.C_e, - C_r_tilde: self.C_r_tilde, - C_R: self.C_R, - C_hat_bin: self.C_hat_bin, - C_y: self.C_y, - C_h1: self.C_h1, - C_h2: self.C_h2, - C_hat_t: self.C_hat_t, - pi: self.pi, - pi_kzg: self.pi_kzg, + Ok(ProofV1 { + C_hat_e, + C_e, + C_r_tilde, + C_R, + C_hat_bin, + C_y, + C_h1, + C_h2, + C_hat_t, + pi, + pi_kzg, compute_load_proof_fields, }) } } +#[derive(Version)] +pub struct ProofV1 { + C_hat_e: G::G2, + C_e: G::G1, + C_r_tilde: G::G1, + C_R: G::G1, + C_hat_bin: G::G2, + C_y: G::G1, + C_h1: G::G1, + C_h2: G::G1, + C_hat_t: G::G2, + pi: G::G1, + pi_kzg: G::G1, + compute_load_proof_fields: Option>, +} + +impl Upgrade> for ProofV1 { + type Error = Infallible; + + fn upgrade(self) -> Result, Self::Error> { + let ProofV1 { + C_hat_e, + C_e, + C_r_tilde, + C_R, + C_hat_bin, + C_y, + C_h1, + C_h2, + C_hat_t, + pi, + pi_kzg, + compute_load_proof_fields, + } = self; + + Ok(Proof { + C_hat_e, + C_e, + C_r_tilde, + C_R, + C_hat_bin, + C_y, + C_h1, + C_h2, + C_hat_t, + pi, + pi_kzg, + compute_load_proof_fields, + hash_mode: PkeV2HashMode::BackwardCompat, + }) + } +} + #[derive(VersionsDispatch)] pub enum ProofVersions { V0(ProofV0), - V1(Proof), + V1(ProofV1), + V2(Proof), } #[derive(VersionsDispatch)] @@ -67,6 +139,7 @@ pub(crate) enum ComputeLoadProofFieldsVersions { V0(ComputeLoadProofFields), } +#[derive(Version)] pub struct CompressedProofV0 where G::G1: Compressible, @@ -88,15 +161,31 @@ where C_hat_w: Option>, } -impl Upgrade> for CompressedProofV0 +impl Upgrade> for CompressedProofV0 where G::G1: Compressible, G::G2: Compressible, { type Error = IncompleteProof; - fn upgrade(self) -> Result, Self::Error> { - let compute_load_proof_fields = match (self.C_hat_h3, self.C_hat_w) { + fn upgrade(self) -> Result, Self::Error> { + let CompressedProofV0 { + C_hat_e, + C_e, + C_r_tilde, + C_R, + C_hat_bin, + C_y, + C_h1, + C_h2, + C_hat_t, + pi, + pi_kzg, + C_hat_h3, + C_hat_w, + } = self; + + let compute_load_proof_fields = match (C_hat_h3, C_hat_w) { (None, None) => None, (Some(C_hat_h3), Some(C_hat_w)) => { Some(CompressedComputeLoadProofFields { C_hat_h3, C_hat_w }) @@ -104,23 +193,84 @@ where _ => return Err(IncompleteProof), }; - Ok(CompressedProof { - C_hat_e: self.C_hat_e, - C_e: self.C_e, - C_r_tilde: self.C_r_tilde, - C_R: self.C_R, - C_hat_bin: self.C_hat_bin, - C_y: self.C_y, - C_h1: self.C_h1, - C_h2: self.C_h2, - C_hat_t: self.C_hat_t, - pi: self.pi, - pi_kzg: self.pi_kzg, + Ok(CompressedProofV1 { + C_hat_e, + C_e, + C_r_tilde, + C_R, + C_hat_bin, + C_y, + C_h1, + C_h2, + C_hat_t, + pi, + pi_kzg, compute_load_proof_fields, }) } } +#[derive(Version)] +pub struct CompressedProofV1 +where + G::G1: Compressible, + G::G2: Compressible, +{ + C_hat_e: CompressedG2, + C_e: CompressedG1, + C_r_tilde: CompressedG1, + C_R: CompressedG1, + C_hat_bin: CompressedG2, + C_y: CompressedG1, + C_h1: CompressedG1, + C_h2: CompressedG1, + C_hat_t: CompressedG2, + pi: CompressedG1, + pi_kzg: CompressedG1, + compute_load_proof_fields: Option>, +} + +impl Upgrade> for CompressedProofV1 +where + G::G1: Compressible, + G::G2: Compressible, +{ + type Error = Infallible; + + fn upgrade(self) -> Result, Self::Error> { + let CompressedProofV1 { + C_hat_e, + C_e, + C_r_tilde, + C_R, + C_hat_bin, + C_y, + C_h1, + C_h2, + C_hat_t, + pi, + pi_kzg, + compute_load_proof_fields, + } = self; + + Ok(CompressedProof { + C_hat_e, + C_e, + C_r_tilde, + C_R, + C_hat_bin, + C_y, + C_h1, + C_h2, + C_hat_t, + pi, + pi_kzg, + compute_load_proof_fields, + hash_mode: PkeV2HashMode::BackwardCompat, + }) + } +} + #[derive(VersionsDispatch)] pub enum CompressedProofVersions where @@ -140,3 +290,9 @@ where #[allow(dead_code)] V0(CompressedComputeLoadProofFields), } + +#[derive(VersionsDispatch)] +pub(crate) enum PkeV2HashModeVersions { + #[allow(dead_code)] + V0(PkeV2HashMode), +} diff --git a/tfhe-zk-pok/src/proofs/pke_v2/hashes.rs b/tfhe-zk-pok/src/proofs/pke_v2/hashes.rs index 7b15a83e1..e29a096f0 100644 --- a/tfhe-zk-pok/src/proofs/pke_v2/hashes.rs +++ b/tfhe-zk-pok/src/proofs/pke_v2/hashes.rs @@ -1,11 +1,137 @@ +use std::iter::successors; + +use serde::{Deserialize, Serialize}; +use tfhe_versionable::Versionize; + /// Scalar generation using the hash random oracle use crate::{ + backward_compatibility::pke_v2::PkeV2HashModeVersions, curve_api::{Curve, FieldOps}, proofs::pke_v2::{compute_crs_params, inf_norm_bound_to_euclidean_squared}, }; use super::{PKEv2DomainSeparators, PublicCommit, PublicParams}; +/// Generates the vector `[1, y, y^2, y^3, ...]` from y +fn generate_powers(scalar: Zp, out: &mut [Zp]) { + let powers_iterator = successors(Some(scalar), move |prev| Some(*prev * scalar)); + + if let Some(val0) = out.get_mut(0) { + *val0 = Zp::ONE; + } + + for (val, power) in out[1..].iter_mut().zip(powers_iterator) { + *val = power; + } +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, Versionize)] +#[versionize(PkeV2HashModeVersions)] +/// Defines how the hash functions will be used to generate values +pub(crate) enum PkeV2HashMode { + /// Compatibility with proofs generated with tfhe-zk-pok 0.6.0 and earlier + BackwardCompat, + /// The basic PkeV2 scheme without the hashes optimizations + Classical, + /// Reduce the number of hashed bytes with various optimizations: + /// - generates only y1 as a hash and derives y = [1, y1, y1^2,...] + /// - only hash R in phi + Compact, +} + +impl PkeV2HashMode { + /// Generate a list of scalars using the hash random oracle. The generated hashes are written to + /// the `output` slice and a byte representation is returned + fn gen_scalars_with_hash( + self, + mut output: &mut [Zp], + inputs: &[&[u8]], + hash_fn: impl FnOnce(&mut [Zp], &[&[u8]]), + ) -> Box<[u8]> { + let mut scalar1 = Zp::ZERO; + + let scalars_gen = match self { + PkeV2HashMode::BackwardCompat | PkeV2HashMode::Classical => &mut output, + PkeV2HashMode::Compact => core::slice::from_mut(&mut scalar1), + }; + + hash_fn(scalars_gen, inputs); + + match self { + PkeV2HashMode::BackwardCompat | PkeV2HashMode::Classical => output + .iter() + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) + .collect::>(), + PkeV2HashMode::Compact => { + generate_powers(scalar1, output); + // since the content of the list is entirely defined by scalar1, it is not necessary + // to hash the full list in the following steps + Box::from(scalar1.to_le_bytes().as_ref()) + } + } + } + + fn gen_scalars(self, output: &mut [Zp], inputs: &[&[u8]]) -> Box<[u8]> { + self.gen_scalars_with_hash(output, inputs, Zp::hash) + } + + /// Generates 128bits scalars that reduce the cost of multi exponentiations. This is + /// not compatible with compact hashes since the scalars need to be independent, so a classical + /// hash function should be used. + /// + /// # panic + /// panics if self is `PkeV2HashMode::Compact` + fn gen_scalars_128b(self, output: &mut [Zp], inputs: &[&[u8]]) -> Box<[u8]> { + if !self.supports_128b_scalars() { + panic!("128b scalars optimization cannot be used in compact hash mode") + }; + self.gen_scalars_with_hash(output, inputs, Zp::hash_128bit) + } + + /// Checks if the hashing mode can be used with `gen_scalars_128` + fn supports_128b_scalars(self) -> bool { + match self { + PkeV2HashMode::BackwardCompat | PkeV2HashMode::Classical => true, + PkeV2HashMode::Compact => false, + } + } + + /// Encode the R matrix (defined as a matrix of -1, 0, 1) as bytes. + fn encode_R(self, R: &[i8]) -> Box<[u8]> { + // The representation is not specified in the mathematical description, so we are free to + // chose a compact one as long as it is injective + match self { + PkeV2HashMode::BackwardCompat => { + // Basic representation where each value is stored in a byte + let R_coeffs = |i: usize, j: usize| R[i + j * 128]; + let columns = R.len() / 128; + + (0..128) + .flat_map(|i| (0..columns).map(move |j| R_coeffs(i, j) as u8)) + .collect() + } + PkeV2HashMode::Compact | PkeV2HashMode::Classical => { + // Since the R matrix is only composed of ternary values, we can pack them by group + // of five instead of using a full u8 for each value + R.chunks(5) + .map(|chunk| { + let mut packed: u8 = 0; + let mut power_of_3: u8 = 1; + + // Cannot overflow since the max value is 3**5 = 243, which fits in a byte + for &byte in chunk { + let mapped = (byte + 1) as u8; + packed += mapped * power_of_3; + power_of_3 *= 3; + } + packed + }) + .collect() + } + } + } +} + // The scalar used for the proof are generated using sha3 as a random oracle. The inputs of the hash // that generates a given scalar are reused for the subsequent hashes. We use the typestate pattern // to propagate the inputs from one hash to the next. @@ -23,6 +149,7 @@ struct RInputs<'a> { n: usize, k: usize, d: usize, + mode: PkeV2HashMode, } pub(super) struct RHash<'a> { @@ -37,6 +164,7 @@ impl<'a> RHash<'a> { C_hat_e_bytes: &'a [u8], C_e_bytes: &'a [u8], C_r_tilde_bytes: &'a [u8], + mode: PkeV2HashMode, ) -> (Box<[i8]>, Self) { let ( &PublicParams { @@ -129,10 +257,7 @@ impl<'a> RHash<'a> { }) .collect::>(); - let R_coeffs = |i: usize, j: usize| R[i + j * 128]; - let R_bytes = (0..128) - .flat_map(|i| (0..(2 * (d + k) + 4)).map(move |j| R_coeffs(i, j) as u8)) - .collect(); + let R_bytes = mode.encode_R(&R); ( R, @@ -150,6 +275,7 @@ impl<'a> RHash<'a> { n, k, d, + mode, }, R_bytes, @@ -157,35 +283,36 @@ impl<'a> RHash<'a> { ) } - pub(super) fn gen_phi(self, C_R_bytes: &'a [u8]) -> ([Zp; 128], PhiHash<'a>) { + fn phi_hash_inputs(&self, phi_inputs: &PhiInputs<'a>) -> [&[u8]; 9] { let Self { R_inputs, R_bytes } = self; + [ + R_inputs.ds.hash_phi(), + &R_inputs.sid_bytes, + R_inputs.metadata, + &R_inputs.x_bytes, + R_bytes, + R_inputs.C_hat_e_bytes, + R_inputs.C_e_bytes, + phi_inputs.C_R_bytes, + R_inputs.C_r_tilde_bytes, + ] + } + + pub(super) fn gen_phi(self, C_R_bytes: &'a [u8]) -> ([Zp; 128], PhiHash<'a>) { + let mode = self.R_inputs.mode; + let phi_inputs = PhiInputs { C_R_bytes }; + let mut phi = [Zp::ZERO; 128]; - Zp::hash( - &mut phi, - &[ - R_inputs.ds.hash_phi(), - &R_inputs.sid_bytes, - R_inputs.metadata, - &R_inputs.x_bytes, - &R_bytes, - R_inputs.C_hat_e_bytes, - R_inputs.C_e_bytes, - C_R_bytes, - R_inputs.C_r_tilde_bytes, - ], - ); - let phi_bytes = phi - .iter() - .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) - .collect::>(); + + let phi_bytes = mode.gen_scalars(&mut phi, &self.phi_hash_inputs(&phi_inputs)); ( phi, PhiHash { - R_inputs, - phi_inputs: PhiInputs { C_R_bytes }, - R_bytes, + R_inputs: self.R_inputs, + phi_inputs, + R_bytes: self.R_bytes, phi_bytes, }, ) @@ -204,7 +331,7 @@ pub(super) struct PhiHash<'a> { } impl<'a> PhiHash<'a> { - pub(super) fn gen_xi(self, C_hat_bin_bytes: &'a [u8]) -> ([Zp; 128], XiHash<'a>) { + fn xi_hash_inputs(&self, xi_inputs: &XiInputs<'a>) -> [&[u8]; 11] { let Self { R_inputs, R_bytes, @@ -212,37 +339,52 @@ impl<'a> PhiHash<'a> { phi_bytes, } = self; - let mut xi = [Zp::ZERO; 128]; - Zp::hash( - &mut xi, - &[ + match R_inputs.mode { + PkeV2HashMode::BackwardCompat | PkeV2HashMode::Classical => [ R_inputs.ds.hash_xi(), &R_inputs.sid_bytes, R_inputs.metadata, &R_inputs.x_bytes, R_inputs.C_hat_e_bytes, R_inputs.C_e_bytes, - &R_bytes, - &phi_bytes, + R_bytes, + phi_bytes, phi_inputs.C_R_bytes, - C_hat_bin_bytes, + xi_inputs.C_hat_bin_bytes, R_inputs.C_r_tilde_bytes, ], - ); + PkeV2HashMode::Compact => [ + R_inputs.ds.hash_xi(), + &R_inputs.sid_bytes, + R_inputs.metadata, + &R_inputs.x_bytes, + R_inputs.C_hat_e_bytes, + R_inputs.C_e_bytes, + &[], // R is only hashed in phi in compact mode + phi_bytes, + phi_inputs.C_R_bytes, + xi_inputs.C_hat_bin_bytes, + R_inputs.C_r_tilde_bytes, + ], + } + } - let xi_bytes = xi - .iter() - .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) - .collect::>(); + pub(super) fn gen_xi(self, C_hat_bin_bytes: &'a [u8]) -> ([Zp; 128], XiHash<'a>) { + let mode = self.R_inputs.mode; + let xi_inputs = XiInputs { C_hat_bin_bytes }; + + let mut xi = [Zp::ZERO; 128]; + + let xi_bytes = mode.gen_scalars(&mut xi, &self.xi_hash_inputs(&xi_inputs)); ( xi, XiHash { - R_inputs, - R_bytes, - phi_inputs, - phi_bytes, - xi_inputs: XiInputs { C_hat_bin_bytes }, + R_inputs: self.R_inputs, + R_bytes: self.R_bytes, + phi_inputs: self.phi_inputs, + phi_bytes: self.phi_bytes, + xi_inputs, xi_bytes, }, ) @@ -263,7 +405,7 @@ pub(super) struct XiHash<'a> { } impl<'a> XiHash<'a> { - pub(super) fn gen_y(self) -> (Vec, YHash<'a>) { + fn y_hash_inputs(&self) -> [&[u8]; 12] { let Self { R_inputs, R_bytes, @@ -273,38 +415,53 @@ impl<'a> XiHash<'a> { xi_bytes, } = self; - let mut y = vec![Zp::ZERO; R_inputs.D + 128 * R_inputs.m]; - Zp::hash( - &mut y, - &[ + match R_inputs.mode { + PkeV2HashMode::BackwardCompat | PkeV2HashMode::Classical => [ R_inputs.ds.hash(), &R_inputs.sid_bytes, R_inputs.metadata, &R_inputs.x_bytes, - &R_bytes, - &phi_bytes, - &xi_bytes, + R_bytes, + phi_bytes, + xi_bytes, R_inputs.C_hat_e_bytes, R_inputs.C_e_bytes, phi_inputs.C_R_bytes, xi_inputs.C_hat_bin_bytes, R_inputs.C_r_tilde_bytes, ], - ); - let y_bytes = y - .iter() - .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) - .collect::>(); + PkeV2HashMode::Compact => [ + R_inputs.ds.hash(), + &R_inputs.sid_bytes, + R_inputs.metadata, + &R_inputs.x_bytes, + &[], // R is only hashed in phi in compact mode + phi_bytes, + xi_bytes, + R_inputs.C_hat_e_bytes, + R_inputs.C_e_bytes, + phi_inputs.C_R_bytes, + xi_inputs.C_hat_bin_bytes, + R_inputs.C_r_tilde_bytes, + ], + } + } + + pub(super) fn gen_y(self) -> (Vec, YHash<'a>) { + let mode = self.R_inputs.mode; + + let mut y = vec![Zp::ZERO; self.R_inputs.D + 128 * self.R_inputs.m]; + let y_bytes = mode.gen_scalars(&mut y, &self.y_hash_inputs()); ( y, YHash { - R_inputs, - R_bytes, - phi_inputs, - phi_bytes, - xi_inputs, - xi_bytes, + R_inputs: self.R_inputs, + R_bytes: self.R_bytes, + phi_inputs: self.phi_inputs, + phi_bytes: self.phi_bytes, + xi_inputs: self.xi_inputs, + xi_bytes: self.xi_bytes, y_bytes, }, ) @@ -322,7 +479,7 @@ pub(super) struct YHash<'a> { } impl<'a> YHash<'a> { - pub(super) fn gen_t(self, C_y_bytes: &'a [u8]) -> (Vec, THash<'a>) { + fn t_hash_input(&self, t_inputs: &TInputs<'a>) -> [&[u8]; 14] { let Self { R_inputs, R_bytes, @@ -333,42 +490,64 @@ impl<'a> YHash<'a> { y_bytes, } = self; - let mut t = vec![Zp::ZERO; R_inputs.n]; - Zp::hash_128bit( - &mut t, - &[ + match R_inputs.mode { + PkeV2HashMode::BackwardCompat | PkeV2HashMode::Classical => [ R_inputs.ds.hash_t(), &R_inputs.sid_bytes, R_inputs.metadata, &R_inputs.x_bytes, - &y_bytes, - &phi_bytes, - &xi_bytes, + y_bytes, + phi_bytes, + xi_bytes, R_inputs.C_hat_e_bytes, R_inputs.C_e_bytes, - &R_bytes, + R_bytes, phi_inputs.C_R_bytes, xi_inputs.C_hat_bin_bytes, R_inputs.C_r_tilde_bytes, - C_y_bytes, + t_inputs.C_y_bytes, ], - ); - let t_bytes = t - .iter() - .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) - .collect::>(); + PkeV2HashMode::Compact => [ + R_inputs.ds.hash_t(), + &R_inputs.sid_bytes, + R_inputs.metadata, + &R_inputs.x_bytes, + y_bytes, + phi_bytes, + xi_bytes, + R_inputs.C_hat_e_bytes, + R_inputs.C_e_bytes, + &[], // R is only hashed in phi in compact mode + phi_inputs.C_R_bytes, + xi_inputs.C_hat_bin_bytes, + R_inputs.C_r_tilde_bytes, + t_inputs.C_y_bytes, + ], + } + } + + pub(super) fn gen_t(self, C_y_bytes: &'a [u8]) -> (Vec, THash<'a>) { + let mode = self.R_inputs.mode; + let t_inputs = TInputs { C_y_bytes }; + + let mut t = vec![Zp::ZERO; self.R_inputs.n]; + let t_bytes = if mode.supports_128b_scalars() { + mode.gen_scalars_128b(&mut t, &self.t_hash_input(&t_inputs)) + } else { + mode.gen_scalars(&mut t, &self.t_hash_input(&t_inputs)) + }; ( t, THash { - R_inputs, - R_bytes, - phi_inputs, - phi_bytes, - xi_inputs, - xi_bytes, - y_bytes, - t_inputs: TInputs { C_y_bytes }, + R_inputs: self.R_inputs, + R_bytes: self.R_bytes, + phi_inputs: self.phi_inputs, + phi_bytes: self.phi_bytes, + xi_inputs: self.xi_inputs, + xi_bytes: self.xi_bytes, + y_bytes: self.y_bytes, + t_inputs, t_bytes, }, ) @@ -392,7 +571,7 @@ pub(super) struct THash<'a> { } impl<'a> THash<'a> { - pub(super) fn gen_theta(self) -> (Vec, ThetaHash<'a>) { + fn theta_hash_input(&self) -> [&[u8]; 15] { let Self { R_inputs, phi_inputs, @@ -405,44 +584,62 @@ impl<'a> THash<'a> { y_bytes, } = self; - let mut theta = vec![Zp::ZERO; R_inputs.d + R_inputs.k]; - Zp::hash( - &mut theta, - &[ + match R_inputs.mode { + PkeV2HashMode::BackwardCompat | PkeV2HashMode::Classical => [ R_inputs.ds.hash_lmap(), &R_inputs.sid_bytes, R_inputs.metadata, &R_inputs.x_bytes, - &y_bytes, - &t_bytes, - &phi_bytes, - &xi_bytes, + y_bytes, + t_bytes, + phi_bytes, + xi_bytes, R_inputs.C_hat_e_bytes, R_inputs.C_e_bytes, - &R_bytes, + R_bytes, phi_inputs.C_R_bytes, xi_inputs.C_hat_bin_bytes, R_inputs.C_r_tilde_bytes, t_inputs.C_y_bytes, ], - ); - let theta_bytes = theta - .iter() - .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) - .collect::>(); + PkeV2HashMode::Compact => [ + R_inputs.ds.hash_lmap(), + &R_inputs.sid_bytes, + R_inputs.metadata, + &R_inputs.x_bytes, + y_bytes, + t_bytes, + phi_bytes, + xi_bytes, + R_inputs.C_hat_e_bytes, + R_inputs.C_e_bytes, + &[], // R is only hashed in phi in compact mode + phi_inputs.C_R_bytes, + xi_inputs.C_hat_bin_bytes, + R_inputs.C_r_tilde_bytes, + t_inputs.C_y_bytes, + ], + } + } + + pub(super) fn gen_theta(self) -> (Vec, ThetaHash<'a>) { + let mode = self.R_inputs.mode; + + let mut theta = vec![Zp::ZERO; self.R_inputs.d + self.R_inputs.k]; + let theta_bytes = mode.gen_scalars(&mut theta, &self.theta_hash_input()); ( theta, ThetaHash { - R_inputs, - R_bytes, - phi_inputs, - phi_bytes, - xi_inputs, - xi_bytes, - y_bytes, - t_inputs, - t_bytes, + R_inputs: self.R_inputs, + R_bytes: self.R_bytes, + phi_inputs: self.phi_inputs, + phi_bytes: self.phi_bytes, + xi_inputs: self.xi_inputs, + xi_bytes: self.xi_bytes, + y_bytes: self.y_bytes, + t_inputs: self.t_inputs, + t_bytes: self.t_bytes, theta_bytes, }, ) @@ -463,7 +660,7 @@ pub(super) struct ThetaHash<'a> { } impl<'a> ThetaHash<'a> { - pub(super) fn gen_omega(self) -> (Vec, OmegaHash<'a>) { + fn omega_hash_input(&self) -> [&[u8]; 16] { let Self { R_inputs, R_bytes, @@ -477,46 +674,69 @@ impl<'a> ThetaHash<'a> { theta_bytes, } = self; - let mut omega = vec![Zp::ZERO; R_inputs.n]; - Zp::hash_128bit( - &mut omega, - &[ + match self.R_inputs.mode { + PkeV2HashMode::BackwardCompat | PkeV2HashMode::Classical => [ R_inputs.ds.hash_w(), &R_inputs.sid_bytes, R_inputs.metadata, &R_inputs.x_bytes, - &y_bytes, - &t_bytes, - &phi_bytes, - &xi_bytes, - &theta_bytes, + y_bytes, + t_bytes, + phi_bytes, + xi_bytes, + theta_bytes, R_inputs.C_hat_e_bytes, R_inputs.C_e_bytes, - &R_bytes, + R_bytes, phi_inputs.C_R_bytes, xi_inputs.C_hat_bin_bytes, R_inputs.C_r_tilde_bytes, t_inputs.C_y_bytes, ], - ); - let omega_bytes = omega - .iter() - .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) - .collect::>(); + PkeV2HashMode::Compact => [ + R_inputs.ds.hash_w(), + &R_inputs.sid_bytes, + R_inputs.metadata, + &R_inputs.x_bytes, + y_bytes, + t_bytes, + phi_bytes, + xi_bytes, + theta_bytes, + R_inputs.C_hat_e_bytes, + R_inputs.C_e_bytes, + &[], // R is only hashed in phi in compact mode + phi_inputs.C_R_bytes, + xi_inputs.C_hat_bin_bytes, + R_inputs.C_r_tilde_bytes, + t_inputs.C_y_bytes, + ], + } + } + + pub(super) fn gen_omega(self) -> (Vec, OmegaHash<'a>) { + let mode = self.R_inputs.mode; + + let mut omega = vec![Zp::ZERO; self.R_inputs.n]; + let omega_bytes = if mode.supports_128b_scalars() { + mode.gen_scalars_128b(&mut omega, &self.omega_hash_input()) + } else { + mode.gen_scalars(&mut omega, &self.omega_hash_input()) + }; ( omega, OmegaHash { - R_inputs, - R_bytes, - phi_inputs, - phi_bytes, - xi_inputs, - xi_bytes, - y_bytes, - t_inputs, - t_bytes, - theta_bytes, + R_inputs: self.R_inputs, + R_bytes: self.R_bytes, + phi_inputs: self.phi_inputs, + phi_bytes: self.phi_bytes, + xi_inputs: self.xi_inputs, + xi_bytes: self.xi_bytes, + y_bytes: self.y_bytes, + t_inputs: self.t_inputs, + t_bytes: self.t_bytes, + theta_bytes: self.theta_bytes, omega_bytes, }, ) @@ -538,7 +758,7 @@ pub(super) struct OmegaHash<'a> { } impl<'a> OmegaHash<'a> { - pub(super) fn gen_delta(self) -> ([Zp; 7], DeltaHash<'a>) { + fn delta_hash_input(&self) -> [&[u8]; 17] { let Self { R_inputs, R_bytes, @@ -553,29 +773,53 @@ impl<'a> OmegaHash<'a> { omega_bytes, } = self; - let mut delta = [Zp::ZERO; 7]; - Zp::hash( - &mut delta, - &[ + match self.R_inputs.mode { + PkeV2HashMode::BackwardCompat | PkeV2HashMode::Classical => [ R_inputs.ds.hash_agg(), &R_inputs.sid_bytes, R_inputs.metadata, &R_inputs.x_bytes, - &y_bytes, - &t_bytes, - &phi_bytes, - &xi_bytes, - &theta_bytes, - &omega_bytes, + y_bytes, + t_bytes, + phi_bytes, + xi_bytes, + theta_bytes, + omega_bytes, R_inputs.C_hat_e_bytes, R_inputs.C_e_bytes, - &R_bytes, + R_bytes, phi_inputs.C_R_bytes, xi_inputs.C_hat_bin_bytes, R_inputs.C_r_tilde_bytes, t_inputs.C_y_bytes, ], - ); + PkeV2HashMode::Compact => [ + R_inputs.ds.hash_agg(), + &R_inputs.sid_bytes, + R_inputs.metadata, + &R_inputs.x_bytes, + y_bytes, + t_bytes, + phi_bytes, + xi_bytes, + theta_bytes, + omega_bytes, + R_inputs.C_hat_e_bytes, + R_inputs.C_e_bytes, + &[], // R is only hashed in phi in compact mode + phi_inputs.C_R_bytes, + xi_inputs.C_hat_bin_bytes, + R_inputs.C_r_tilde_bytes, + t_inputs.C_y_bytes, + ], + } + } + + pub(super) fn gen_delta(self) -> ([Zp; 7], DeltaHash<'a>) { + let mut delta = [Zp::ZERO; 7]; + + // Delta does not use the compact hash optimization + Zp::hash(&mut delta, &self.delta_hash_input()); let delta_bytes = delta .iter() .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) @@ -584,16 +828,17 @@ impl<'a> OmegaHash<'a> { ( delta, DeltaHash { - R_inputs, - R_bytes, - phi_inputs, - phi_bytes, - xi_inputs, - xi_bytes, - y_bytes, - t_inputs, - t_bytes, - theta_bytes, + R_inputs: self.R_inputs, + R_bytes: self.R_bytes, + phi_inputs: self.phi_inputs, + phi_bytes: self.phi_bytes, + xi_inputs: self.xi_inputs, + xi_bytes: self.xi_bytes, + y_bytes: self.y_bytes, + t_inputs: self.t_inputs, + t_bytes: self.t_bytes, + theta_bytes: self.theta_bytes, + omega_bytes: self.omega_bytes, delta_bytes, }, ) @@ -611,18 +856,12 @@ pub(super) struct DeltaHash<'a> { t_inputs: TInputs<'a>, t_bytes: Box<[u8]>, theta_bytes: Box<[u8]>, + omega_bytes: Box<[u8]>, delta_bytes: Box<[u8]>, } impl<'a> DeltaHash<'a> { - pub(super) fn gen_z( - self, - C_h1_bytes: &'a [u8], - C_h2_bytes: &'a [u8], - C_hat_t_bytes: &'a [u8], - C_hat_h3_bytes: &'a [u8], - C_hat_omega_bytes: &'a [u8], - ) -> (Zp, ZHash<'a>) { + fn z_hash_input(&self, z_inputs: &ZInputs<'a>) -> [&[u8]; 23] { let Self { R_inputs, R_bytes, @@ -634,59 +873,126 @@ impl<'a> DeltaHash<'a> { t_inputs, t_bytes, theta_bytes, + omega_bytes, delta_bytes, } = self; - let mut z = Zp::ZERO; - Zp::hash( - core::slice::from_mut(&mut z), - &[ + match R_inputs.mode { + PkeV2HashMode::BackwardCompat => { + [ + R_inputs.ds.hash_z(), + &R_inputs.sid_bytes, + R_inputs.metadata, + &R_inputs.x_bytes, + y_bytes, + t_bytes, + phi_bytes, + &R_inputs.x_bytes, // x is duplicated but we keep it for backward compat + theta_bytes, + &[], // Omega is not included for backward compat + delta_bytes, + R_inputs.C_hat_e_bytes, + R_inputs.C_e_bytes, + R_bytes, + phi_inputs.C_R_bytes, + xi_inputs.C_hat_bin_bytes, + R_inputs.C_r_tilde_bytes, + t_inputs.C_y_bytes, + z_inputs.C_h1_bytes, + z_inputs.C_h2_bytes, + z_inputs.C_hat_t_bytes, + z_inputs.C_hat_h3_bytes, + z_inputs.C_hat_omega_bytes, + ] + } + PkeV2HashMode::Classical => [ R_inputs.ds.hash_z(), &R_inputs.sid_bytes, R_inputs.metadata, &R_inputs.x_bytes, - &y_bytes, - &t_bytes, - &phi_bytes, - &R_inputs.x_bytes, // x is duplicated but we have to keep it for backward compat - &theta_bytes, - &delta_bytes, + y_bytes, + t_bytes, + phi_bytes, + xi_bytes, + theta_bytes, + omega_bytes, + delta_bytes, R_inputs.C_hat_e_bytes, R_inputs.C_e_bytes, - &R_bytes, + R_bytes, phi_inputs.C_R_bytes, xi_inputs.C_hat_bin_bytes, R_inputs.C_r_tilde_bytes, t_inputs.C_y_bytes, - C_h1_bytes, - C_h2_bytes, - C_hat_t_bytes, - C_hat_h3_bytes, - C_hat_omega_bytes, + z_inputs.C_h1_bytes, + z_inputs.C_h2_bytes, + z_inputs.C_hat_t_bytes, + z_inputs.C_hat_h3_bytes, + z_inputs.C_hat_omega_bytes, ], - ); + PkeV2HashMode::Compact => [ + R_inputs.ds.hash_z(), + &R_inputs.sid_bytes, + R_inputs.metadata, + &R_inputs.x_bytes, + y_bytes, + t_bytes, + phi_bytes, + xi_bytes, + theta_bytes, + omega_bytes, + delta_bytes, + R_inputs.C_hat_e_bytes, + R_inputs.C_e_bytes, + &[], // R is only hashed in phi in compact mode + phi_inputs.C_R_bytes, + xi_inputs.C_hat_bin_bytes, + R_inputs.C_r_tilde_bytes, + t_inputs.C_y_bytes, + z_inputs.C_h1_bytes, + z_inputs.C_h2_bytes, + z_inputs.C_hat_t_bytes, + z_inputs.C_hat_h3_bytes, + z_inputs.C_hat_omega_bytes, + ], + } + } + + pub(super) fn gen_z( + self, + C_h1_bytes: &'a [u8], + C_h2_bytes: &'a [u8], + C_hat_t_bytes: &'a [u8], + C_hat_h3_bytes: &'a [u8], + C_hat_omega_bytes: &'a [u8], + ) -> (Zp, ZHash<'a>) { + let z_inputs = ZInputs { + C_h1_bytes, + C_h2_bytes, + C_hat_t_bytes, + C_hat_h3_bytes, + C_hat_omega_bytes, + }; + + let mut z = Zp::ZERO; + Zp::hash(core::slice::from_mut(&mut z), &self.z_hash_input(&z_inputs)); ( z, ZHash { - R_inputs, - R_bytes, - phi_inputs, - phi_bytes, - xi_inputs, - xi_bytes, - y_bytes, - t_inputs, - t_bytes, - theta_bytes, - delta_bytes, - z_inputs: ZInputs { - C_h1_bytes, - C_h2_bytes, - C_hat_t_bytes, - C_hat_h3_bytes, - C_hat_omega_bytes, - }, + R_inputs: self.R_inputs, + R_bytes: self.R_bytes, + phi_inputs: self.phi_inputs, + phi_bytes: self.phi_bytes, + xi_inputs: self.xi_inputs, + xi_bytes: self.xi_bytes, + y_bytes: self.y_bytes, + t_inputs: self.t_inputs, + t_bytes: self.t_bytes, + theta_bytes: self.theta_bytes, + omega_bytes: self.omega_bytes, + delta_bytes: self.delta_bytes, + z_inputs, z_bytes: Box::from(z.to_le_bytes().as_ref()), }, ) @@ -712,13 +1018,21 @@ pub(super) struct ZHash<'a> { t_inputs: TInputs<'a>, t_bytes: Box<[u8]>, theta_bytes: Box<[u8]>, + omega_bytes: Box<[u8]>, delta_bytes: Box<[u8]>, z_inputs: ZInputs<'a>, z_bytes: Box<[u8]>, } impl<'a> ZHash<'a> { - pub(super) fn gen_chi(self, p_h1: Zp, p_h2: Zp, p_t: Zp) -> Zp { + fn chi_hash_input<'b>( + &'b self, + p_h1: &'b [u8], + p_h2: &'b [u8], + p_t: &'b [u8], + p_h3: &'b [u8], + p_omega: &'b [u8], + ) -> [&'b [u8]; 29] { let Self { R_inputs, R_bytes, @@ -730,28 +1044,28 @@ impl<'a> ZHash<'a> { t_inputs, t_bytes, theta_bytes, + omega_bytes, delta_bytes, z_inputs, z_bytes, } = self; - let mut chi = Zp::ZERO; - Zp::hash( - core::slice::from_mut(&mut chi), - &[ + match R_inputs.mode { + PkeV2HashMode::BackwardCompat => [ R_inputs.ds.hash_chi(), &R_inputs.sid_bytes, R_inputs.metadata, &R_inputs.x_bytes, - &y_bytes, - &t_bytes, - &phi_bytes, - &xi_bytes, - &theta_bytes, - &delta_bytes, + y_bytes, + t_bytes, + phi_bytes, + xi_bytes, + theta_bytes, + &[], // Omega is not included for backward compat + delta_bytes, R_inputs.C_hat_e_bytes, R_inputs.C_e_bytes, - &R_bytes, + R_bytes, phi_inputs.C_R_bytes, xi_inputs.C_hat_bin_bytes, R_inputs.C_r_tilde_bytes, @@ -761,11 +1075,103 @@ impl<'a> ZHash<'a> { z_inputs.C_hat_t_bytes, z_inputs.C_hat_h3_bytes, z_inputs.C_hat_omega_bytes, - &z_bytes, + z_bytes, + p_h1, + p_h2, + p_t, + // p_h3 and p_omega are not hashed for backward compatibility reasons + &[], + &[], + ], + PkeV2HashMode::Classical => [ + R_inputs.ds.hash_chi(), + &R_inputs.sid_bytes, + R_inputs.metadata, + &R_inputs.x_bytes, + y_bytes, + t_bytes, + phi_bytes, + xi_bytes, + theta_bytes, + omega_bytes, + delta_bytes, + R_inputs.C_hat_e_bytes, + R_inputs.C_e_bytes, + R_bytes, + phi_inputs.C_R_bytes, + xi_inputs.C_hat_bin_bytes, + R_inputs.C_r_tilde_bytes, + t_inputs.C_y_bytes, + z_inputs.C_h1_bytes, + z_inputs.C_h2_bytes, + z_inputs.C_hat_t_bytes, + z_inputs.C_hat_h3_bytes, + z_inputs.C_hat_omega_bytes, + z_bytes, + p_h1, + p_h2, + p_t, + p_h3, + p_omega, + ], + PkeV2HashMode::Compact => [ + R_inputs.ds.hash_chi(), + &R_inputs.sid_bytes, + R_inputs.metadata, + &R_inputs.x_bytes, + y_bytes, + t_bytes, + phi_bytes, + xi_bytes, + theta_bytes, + omega_bytes, + delta_bytes, + R_inputs.C_hat_e_bytes, + R_inputs.C_e_bytes, + &[], // R is only hashed in phi in compact mode + phi_inputs.C_R_bytes, + xi_inputs.C_hat_bin_bytes, + R_inputs.C_r_tilde_bytes, + t_inputs.C_y_bytes, + z_inputs.C_h1_bytes, + z_inputs.C_h2_bytes, + z_inputs.C_hat_t_bytes, + z_inputs.C_hat_h3_bytes, + z_inputs.C_hat_omega_bytes, + z_bytes, + p_h1, + p_h2, + p_t, + p_h3, + p_omega, + ], + } + } + + pub(super) fn gen_chi( + self, + p_h1: Zp, + p_h2: Zp, + p_t: Zp, + p_h3_opt: Option, + p_omega_opt: Option, + ) -> Zp { + let mut chi = Zp::ZERO; + + let p_h3 = p_h3_opt.map_or(Box::from([]), |p_h3| Box::from(p_h3.to_le_bytes().as_ref())); + let p_omega = p_omega_opt.map_or(Box::from([]), |p_omega| { + Box::from(p_omega.to_le_bytes().as_ref()) + }); + + Zp::hash( + core::slice::from_mut(&mut chi), + &self.chi_hash_input( p_h1.to_le_bytes().as_ref(), p_h2.to_le_bytes().as_ref(), p_t.to_le_bytes().as_ref(), - ], + &p_h3, + &p_omega, + ), ); chi diff --git a/tfhe-zk-pok/src/proofs/pke_v2/mod.rs b/tfhe-zk-pok/src/proofs/pke_v2/mod.rs index e306c5c42..0bca66c41 100644 --- a/tfhe-zk-pok/src/proofs/pke_v2/mod.rs +++ b/tfhe-zk-pok/src/proofs/pke_v2/mod.rs @@ -12,6 +12,7 @@ use crate::serialization::{ }; use core::marker::PhantomData; + use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -19,6 +20,8 @@ mod hashes; use hashes::RHash; +pub(crate) use hashes::PkeV2HashMode; + fn bit_iter(x: u64, nbits: u32) -> impl Iterator { (0..nbits).map(move |idx| ((x >> idx) & 1) != 0) } @@ -366,8 +369,8 @@ pub struct Proof { pub(crate) C_hat_t: G::G2, pub(crate) pi: G::G1, pub(crate) pi_kzg: G::G1, - pub(crate) compute_load_proof_fields: Option>, + pub(crate) hash_mode: PkeV2HashMode, } impl Proof { @@ -390,6 +393,7 @@ impl Proof { pi, pi_kzg, ref compute_load_proof_fields, + hash_mode: _, } = self; C_hat_e.validate_projective() @@ -455,8 +459,8 @@ where pub(crate) C_hat_t: CompressedG2, pub(crate) pi: CompressedG1, pub(crate) pi_kzg: CompressedG1, - pub(crate) compute_load_proof_fields: Option>, + pub(crate) hash_mode: PkeV2HashMode, } #[derive(Serialize, Deserialize, Versionize)] @@ -497,6 +501,7 @@ where pi, pi_kzg, compute_load_proof_fields, + hash_mode, } = self; CompressedProof { @@ -518,6 +523,7 @@ where C_hat_w: C_hat_w.compress(), }, ), + hash_mode: *hash_mode, } } @@ -535,6 +541,7 @@ where pi, pi_kzg, compute_load_proof_fields, + hash_mode, } = compressed; Ok(Proof { @@ -562,6 +569,7 @@ where } else { None }, + hash_mode, }) } } @@ -812,6 +820,7 @@ pub fn prove( metadata, load, seed, + PkeV2HashMode::Compact, ProofSanityCheckMode::Panic, ) } @@ -822,6 +831,7 @@ fn prove_impl( metadata: &[u8], load: ComputeLoad, seed: &[u8], + hash_mode: PkeV2HashMode, sanity_check_mode: ProofSanityCheckMode, ) -> Proof { _ = load; @@ -966,6 +976,7 @@ fn prove_impl( C_hat_e_bytes.as_ref(), C_e_bytes.as_ref(), C_r_tilde_bytes.as_ref(), + hash_mode, ); let R = |i: usize, j: usize| R[i + j * 128]; @@ -1430,7 +1441,7 @@ fn prove_impl( ComputeLoad::Proof => vec![G::Zp::ZERO; 1 + n], ComputeLoad::Verify => vec![], }; - let mut P_w = match load { + let mut P_omega = match load { ComputeLoad::Proof => vec![G::Zp::ZERO; 1 + d + k + 4], ComputeLoad::Verify => vec![], }; @@ -1502,15 +1513,15 @@ fn prove_impl( } } - if !P_w.is_empty() { - P_w[1..].copy_from_slice(&omega[..d + k + 4]); + if !P_omega.is_empty() { + P_omega[1..].copy_from_slice(&omega[..d + k + 4]); } let mut p_h1 = G::Zp::ZERO; let mut p_h2 = G::Zp::ZERO; let mut p_t = G::Zp::ZERO; let mut p_h3 = G::Zp::ZERO; - let mut p_w = G::Zp::ZERO; + let mut p_omega = G::Zp::ZERO; let mut pow = G::Zp::ONE; for j in 0..n + 1 { @@ -1521,14 +1532,21 @@ fn prove_impl( if j < P_h3.len() { p_h3 += P_h3[j] * pow; } - if j < P_w.len() { - p_w += P_w[j] * pow; + if j < P_omega.len() { + p_omega += P_omega[j] * pow; } pow = pow * z; } - let chi = z_hash.gen_chi(p_h1, p_h2, p_t); + let p_h3_opt = if P_h3.is_empty() { None } else { Some(p_h3) }; + let p_omega_opt = if P_omega.is_empty() { + None + } else { + Some(p_omega) + }; + + let chi = z_hash.gen_chi(p_h1, p_h2, p_t, p_h3_opt, p_omega_opt); let mut Q_kzg = vec![G::Zp::ZERO; 1 + n]; let chi2 = chi * chi; @@ -1539,11 +1557,11 @@ fn prove_impl( if j < P_h3.len() { Q_kzg[j] += chi3 * P_h3[j]; } - if j < P_w.len() { - Q_kzg[j] += chi4 * P_w[j]; + if j < P_omega.len() { + Q_kzg[j] += chi4 * P_omega[j]; } } - Q_kzg[0] -= p_h1 + chi * p_h2 + chi2 * p_t + chi3 * p_h3 + chi4 * p_w; + Q_kzg[0] -= p_h1 + chi * p_h2 + chi2 * p_t + chi3 * p_h3 + chi4 * p_omega; // https://en.wikipedia.org/wiki/Polynomial_long_division#Pseudocode let mut q = vec![G::Zp::ZERO; n]; @@ -1568,6 +1586,7 @@ fn prove_impl( pi, pi_kzg, compute_load_proof_fields, + hash_mode, } } @@ -1703,6 +1722,7 @@ pub fn verify_impl( pi, pi_kzg, ref compute_load_proof_fields, + hash_mode, } = proof; let pairing = G::Gt::pairing; @@ -1766,6 +1786,7 @@ pub fn verify_impl( C_hat_e_bytes.as_ref(), C_e_bytes.as_ref(), C_r_tilde_bytes.as_ref(), + hash_mode, ); let R = |i: usize, j: usize| R[i + j * 128]; @@ -1896,7 +1917,7 @@ pub fn verify_impl( ComputeLoad::Proof => vec![G::Zp::ZERO; 1 + n], ComputeLoad::Verify => vec![], }; - let mut P_w = match load { + let mut P_omega = match load { ComputeLoad::Proof => vec![G::Zp::ZERO; 1 + d + k + 4], ComputeLoad::Verify => vec![], }; @@ -1968,15 +1989,15 @@ pub fn verify_impl( } } - if !P_w.is_empty() { - P_w[1..].copy_from_slice(&omega[..d + k + 4]); + if !P_omega.is_empty() { + P_omega[1..].copy_from_slice(&omega[..d + k + 4]); } let mut p_h1 = G::Zp::ZERO; let mut p_h2 = G::Zp::ZERO; let mut p_t = G::Zp::ZERO; let mut p_h3 = G::Zp::ZERO; - let mut p_w = G::Zp::ZERO; + let mut p_omega = G::Zp::ZERO; let mut pow = G::Zp::ONE; for j in 0..n + 1 { @@ -1987,14 +2008,21 @@ pub fn verify_impl( if j < P_h3.len() { p_h3 += P_h3[j] * pow; } - if j < P_w.len() { - p_w += P_w[j] * pow; + if j < P_omega.len() { + p_omega += P_omega[j] * pow; } pow = pow * z; } - let chi = z_hash.gen_chi(p_h1, p_h2, p_t); + let p_h3_opt = if P_h3.is_empty() { None } else { Some(p_h3) }; + let p_omega_opt = if P_omega.is_empty() { + None + } else { + Some(p_omega) + }; + + let chi = z_hash.gen_chi(p_h1, p_h2, p_t, p_h3_opt, p_omega_opt); let chi2 = chi * chi; let chi3 = chi2 * chi; @@ -2012,7 +2040,7 @@ pub fn verify_impl( C_hat += C_hat_w.mul_scalar(chi4); } C_hat - } - g_hat.mul_scalar(p_t * chi2 + p_h3 * chi3 + p_w * chi4), + } - g_hat.mul_scalar(p_t * chi2 + p_h3 * chi3 + p_omega * chi4), ); let rhs = pairing( pi_kzg, @@ -2200,12 +2228,122 @@ mod tests { } } + #[test] + fn test_pke_legacy_hash() { + let PkeTestParameters { + d, + k, + B, + q, + t, + msbs_zero_padding_bit_count, + } = PKEV2_TEST_PARAMS; + + let effective_cleartext_t = t >> msbs_zero_padding_bit_count; + + let seed = thread_rng().gen(); + println!("pkev2 legacy hash seed: {seed:x}"); + let rng = &mut StdRng::seed_from_u64(seed); + + let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS); + let ct = testcase.encrypt(PKEV2_TEST_PARAMS); + + let fake_e1 = (0..d) + .map(|_| (rng.gen::() % (2 * B)) as i64 - B as i64) + .collect::>(); + let fake_e2 = (0..k) + .map(|_| (rng.gen::() % (2 * B)) as i64 - B as i64) + .collect::>(); + + let fake_r = (0..d) + .map(|_| (rng.gen::() % 2) as i64) + .collect::>(); + + let fake_m = (0..k) + .map(|_| (rng.gen::() % effective_cleartext_t) as i64) + .collect::>(); + + let mut fake_metadata = [255u8; METADATA_LEN]; + fake_metadata.fill_with(|| rng.gen::()); + + // To check management of bigger k_max from CRS during test + let crs_k = k + 1 + (rng.gen::() % (d - k)); + + let public_param = crs_gen::(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng); + + for (use_fake_r, use_fake_e1, use_fake_e2, use_fake_m, use_fake_metadata_verify) in itertools::iproduct!( + [false, true], + [false, true], + [false, true], + [false, true], + [false, true] + ) { + let (public_commit, private_commit) = commit( + testcase.a.clone(), + testcase.b.clone(), + ct.c1.clone(), + ct.c2.clone(), + if use_fake_r { + fake_r.clone() + } else { + testcase.r.clone() + }, + if use_fake_e1 { + fake_e1.clone() + } else { + testcase.e1.clone() + }, + if use_fake_m { + fake_m.clone() + } else { + testcase.m.clone() + }, + if use_fake_e2 { + fake_e2.clone() + } else { + testcase.e2.clone() + }, + &public_param, + ); + + for load in [ComputeLoad::Proof, ComputeLoad::Verify] { + for hash_mode in [PkeV2HashMode::BackwardCompat, PkeV2HashMode::Classical] { + let proof = prove_impl( + (&public_param, &public_commit), + &private_commit, + &testcase.metadata, + load, + &seed.to_le_bytes(), + hash_mode, + ProofSanityCheckMode::Panic, + ); + + let verify_metadata = if use_fake_metadata_verify { + &fake_metadata + } else { + &testcase.metadata + }; + + assert_eq!( + verify(&proof, (&public_param, &public_commit), verify_metadata).is_err(), + use_fake_e1 + || use_fake_e2 + || use_fake_r + || use_fake_m + || use_fake_metadata_verify + ); + } + } + } + } + fn prove_and_verify( testcase: &PkeTestcase, ct: &PkeTestCiphertext, crs: &PublicParams, load: ComputeLoad, seed: &[u8], + hash_mode: PkeV2HashMode, sanity_check_mode: ProofSanityCheckMode, ) -> VerificationResult { let (public_commit, private_commit) = commit( @@ -2226,6 +2364,7 @@ mod tests { &testcase.metadata, load, seed, + hash_mode, sanity_check_mode, ); @@ -2236,20 +2375,22 @@ mod tests { } } + #[allow(clippy::too_many_arguments)] fn assert_prove_and_verify( testcase: &PkeTestcase, ct: &PkeTestCiphertext, testcase_name: &str, crs: &PublicParams, seed: &[u8], + hash_mode: PkeV2HashMode, sanity_check_mode: ProofSanityCheckMode, expected_result: VerificationResult, ) { for load in [ComputeLoad::Proof, ComputeLoad::Verify] { assert_eq!( - prove_and_verify(testcase, ct, crs, load, seed, sanity_check_mode), + prove_and_verify(testcase, ct, crs, load, seed, hash_mode, sanity_check_mode), expected_result, - "Testcase {testcase_name} with load {load} failed" + "Testcase {testcase_name} {hash_mode:?} hash with load {load} failed" ) } } @@ -2515,6 +2656,7 @@ mod tests { &format!("{name}_crs"), &crs, &seed.to_le_bytes(), + PkeV2HashMode::Compact, ProofSanityCheckMode::Ignore, expected_result, ); @@ -2524,6 +2666,7 @@ mod tests { &format!("{name}_crs_max_k"), &crs_max_k, &seed.to_le_bytes(), + PkeV2HashMode::Compact, ProofSanityCheckMode::Ignore, expected_result, ); @@ -2617,6 +2760,7 @@ mod tests { test_name, &public_param, &seed.to_le_bytes(), + PkeV2HashMode::Compact, ProofSanityCheckMode::Panic, VerificationResult::Reject, ); @@ -2789,6 +2933,7 @@ mod tests { "testcase_bad_delta", &crs, &seed.to_le_bytes(), + PkeV2HashMode::Compact, ProofSanityCheckMode::Panic, VerificationResult::Reject, ); @@ -2830,6 +2975,7 @@ mod tests { &format!("testcase_big_params_{bound:?}"), &crs, &seed.to_le_bytes(), + PkeV2HashMode::Compact, ProofSanityCheckMode::Panic, VerificationResult::Accept, );