From 87dbfdcd5ef9205623a2a36bfa630d8787f7c35d Mon Sep 17 00:00:00 2001 From: Nicolas Sarlin Date: Mon, 18 Nov 2024 17:27:57 +0100 Subject: [PATCH] fix(zk): recompute B according to k in proof and use squared bounds This removes the need for sqrt operations also fix a proof slack was too big in v2 --- tfhe-zk-pok/src/backward_compatibility/mod.rs | 72 ++++++- tfhe-zk-pok/src/proofs/pke_v2.rs | 201 ++++++++++-------- tfhe-zk-pok/src/serialization.rs | 37 ++-- 3 files changed, 202 insertions(+), 108 deletions(-) diff --git a/tfhe-zk-pok/src/backward_compatibility/mod.rs b/tfhe-zk-pok/src/backward_compatibility/mod.rs index c54843549..81769bd91 100644 --- a/tfhe-zk-pok/src/backward_compatibility/mod.rs +++ b/tfhe-zk-pok/src/backward_compatibility/mod.rs @@ -1,12 +1,18 @@ +// to follow the notation of the paper +#![allow(non_snake_case)] + pub mod pke; pub mod pke_v2; +use std::convert::Infallible; use std::error::Error; use std::fmt::Display; -use tfhe_versionable::VersionsDispatch; +use tfhe_versionable::{Upgrade, Version, VersionsDispatch}; use crate::curve_api::Curve; +use crate::four_squares::{isqrt, sqr}; +use crate::proofs::pke_v2::Bound; use crate::proofs::GroupElements; use crate::serialization::{ SerializableAffine, SerializableCubicExtField, SerializableFp, SerializableFp2, @@ -65,6 +71,65 @@ pub(crate) enum SerializableGroupElementsVersions { V0(SerializableGroupElements), } +#[derive(Version)] +pub struct SerializablePKEv2PublicParamsV0 { + pub(crate) g_lists: SerializableGroupElements, + pub(crate) D: usize, + pub n: usize, + pub d: usize, + pub k: usize, + pub B: u64, + pub B_r: u64, + pub B_bound: u64, + pub m_bound: usize, + pub q: u64, + pub t: u64, + pub msbs_zero_padding_bit_count: u64, + // We use Vec since serde does not support fixed size arrays of 256 elements + pub(crate) hash: Vec, + pub(crate) hash_R: Vec, + pub(crate) hash_t: Vec, + pub(crate) hash_w: Vec, + pub(crate) hash_agg: Vec, + pub(crate) hash_lmap: Vec, + pub(crate) hash_phi: Vec, + pub(crate) hash_xi: Vec, + pub(crate) hash_z: Vec, + pub(crate) hash_chi: Vec, +} + +impl Upgrade for SerializablePKEv2PublicParamsV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + let slack_factor = isqrt((self.d + self.k) as u128) as u64; + let B_inf = self.B / slack_factor; + Ok(SerializablePKEv2PublicParams { + g_lists: self.g_lists, + D: self.D, + n: self.n, + d: self.d, + k: self.k, + B_bound_squared: sqr(self.B_bound as u128), + B_inf, + q: self.q, + t: self.t, + msbs_zero_padding_bit_count: self.msbs_zero_padding_bit_count, + bound_type: Bound::CS, + hash: self.hash, + hash_R: self.hash_R, + hash_t: self.hash_t, + hash_w: self.hash_w, + hash_agg: self.hash_agg, + hash_lmap: self.hash_lmap, + hash_phi: self.hash_phi, + hash_xi: self.hash_xi, + hash_z: self.hash_z, + hash_chi: self.hash_chi, + }) + } +} + #[derive(VersionsDispatch)] pub enum SerializablePKEv2PublicParamsVersions { V0(SerializablePKEv2PublicParams), @@ -74,3 +139,8 @@ pub enum SerializablePKEv2PublicParamsVersions { pub enum SerializablePKEv1PublicParamsVersions { V0(SerializablePKEv1PublicParams), } + +#[derive(VersionsDispatch)] +pub enum BoundVersions { + V0(Bound), +} diff --git a/tfhe-zk-pok/src/proofs/pke_v2.rs b/tfhe-zk-pok/src/proofs/pke_v2.rs index d6a7a904f..4e829f7d3 100644 --- a/tfhe-zk-pok/src/proofs/pke_v2.rs +++ b/tfhe-zk-pok/src/proofs/pke_v2.rs @@ -3,6 +3,7 @@ use super::*; use crate::backward_compatibility::pke_v2::{CompressedProofVersions, ProofVersions}; +use crate::backward_compatibility::BoundVersions; use crate::curve_api::{CompressedG1, CompressedG2}; use crate::four_squares::*; use crate::serialization::{ @@ -35,13 +36,13 @@ pub struct PublicParams { pub n: usize, pub d: usize, pub k: usize, - pub B: u64, - pub B_r: u64, - pub B_bound: u64, - pub m_bound: usize, + // We store the square of the bound to avoid rounding on sqrt operations + pub B_bound_squared: u128, + pub B_inf: u64, pub q: u64, pub t: u64, pub msbs_zero_padding_bit_count: u64, + pub bound_type: Bound, pub(crate) hash: [u8; HASH_METADATA_LEN_BYTES], pub(crate) hash_R: [u8; HASH_METADATA_LEN_BYTES], pub(crate) hash_t: [u8; HASH_METADATA_LEN_BYTES], @@ -72,13 +73,12 @@ where n, d, k, - B, - B_r, - B_bound, - m_bound, + B_bound_squared, + B_inf, q, t, msbs_zero_padding_bit_count, + bound_type, hash, hash_R, hash_t, @@ -96,13 +96,12 @@ where n: *n, d: *d, k: *k, - B: *B, - B_r: *B_r, - B_bound: *B_bound, - m_bound: *m_bound, + B_inf: *B_inf, + B_bound_squared: *B_bound_squared, q: *q, t: *t, msbs_zero_padding_bit_count: *msbs_zero_padding_bit_count, + bound_type: *bound_type, hash: hash.to_vec(), hash_R: hash_R.to_vec(), hash_t: hash_t.to_vec(), @@ -123,13 +122,12 @@ where n, d, k, - B, - B_r, - B_bound, - m_bound, + B_bound_squared, + B_inf, q, t, msbs_zero_padding_bit_count, + bound_type, hash, hash_R, hash_t, @@ -147,13 +145,12 @@ where n, d, k, - B, - B_r, - B_bound, - m_bound, + B_bound_squared, + B_inf, q, t, msbs_zero_padding_bit_count, + bound_type, hash: try_vec_to_array(hash)?, hash_R: try_vec_to_array(hash_R)?, hash_t: try_vec_to_array(hash_t)?, @@ -175,11 +172,11 @@ impl PublicParams { g_hat_list: Vec>, d: usize, k: usize, - B: u64, + B_inf: u64, q: u64, t: u64, msbs_zero_padding_bit_count: u64, - bound: Bound, + bound_type: Bound, hash: [u8; HASH_METADATA_LEN_BYTES], hash_R: [u8; HASH_METADATA_LEN_BYTES], hash_t: [u8; HASH_METADATA_LEN_BYTES], @@ -191,21 +188,21 @@ impl PublicParams { hash_z: [u8; HASH_METADATA_LEN_BYTES], hash_chi: [u8; HASH_METADATA_LEN_BYTES], ) -> Self { - let (n, D, B_r, B_bound, m_bound) = - compute_crs_params(d, k, B, q, t, msbs_zero_padding_bit_count, bound); + let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k); + let (n, D, B_bound_squared, _) = + compute_crs_params(d, k, B_squared, t, msbs_zero_padding_bit_count, bound_type); Self { g_lists: GroupElements::::from_vec(g_list, g_hat_list), D, n, d, k, - B, - B_r, - B_bound, - m_bound, + B_bound_squared, + B_inf, q, t, msbs_zero_padding_bit_count, + bound_type, hash, hash_R, hash_t, @@ -220,7 +217,9 @@ impl PublicParams { } pub fn exclusive_max_noise(&self) -> u64 { - self.B + // Here we return the bound without slack because users aren't supposed to generate noise + // inside the slack + self.B_inf + 1 } /// Check if the crs can be used to generate or verify a proof @@ -478,72 +477,90 @@ pub struct PrivateCommit { __marker: PhantomData, } -#[derive(Copy, Clone, Debug)] +#[derive(PartialEq, Copy, Clone, Debug, Serialize, Deserialize, Versionize)] +#[versionize(BoundVersions)] pub enum Bound { GHL, CS, } +fn ceil_ilog2(value: u128) -> u64 { + value.ilog2() as u64 + if value.is_power_of_two() { 0 } else { 1 } +} + pub fn compute_crs_params( d: usize, k: usize, - B: u64, - _q: u64, // we keep q here to make sure the API is consistent with [crs_gen] + B_squared: u128, t: u64, msbs_zero_padding_bit_count: u64, - bound: Bound, -) -> (usize, usize, u64, u64, usize) { - let B_r = d as u64 / 2 + 1; - let B_bound = { - let B = B as f64; - let d = d as f64; - let k = k as f64; + bound_type: Bound, +) -> (usize, usize, u128, usize) { + let mut B_bound_squared = { + (match bound_type { + // GHL factor is 9.75, 9.75**2 = 95.0625 + // Result is multiplied and divided by 10000 to avoid floating point operations + Bound::GHL => 950625, + Bound::CS => (2 * (d + k) + 4) as u128, + }) * (B_squared + (sqr(d + 2) * (d + k)) as u128 / 4) + }; - (match bound { - Bound::GHL => 9.75, - Bound::CS => f64::sqrt(2.0 * (d + k) + 4.0), - }) * f64::sqrt(sqr(B) + (sqr(d + 2.0) * (d + k)) / 4.0) + if bound_type == Bound::GHL { + B_bound_squared /= 10000; } - .ceil() as u64; - // Formula is round_up(1 + B_bound.ilog2()) so we convert it to +2 - let m_bound = 2 + B_bound.ilog2() as usize; + // Formula is round_up(1 + B_bound.ilog2()). + // Since we use B_bound_square, the log is divided by 2 + let m_bound = 1 + ceil_ilog2(B_bound_squared).div_ceil(2) as usize; // This is also the effective t for encryption let effective_t_for_decomposition = t >> msbs_zero_padding_bit_count; + + // formula in Prove_pp: 2. let D = d + k * effective_t_for_decomposition.ilog2() as usize; let n = D + 128 * m_bound; - (n, D, B_r, B_bound, m_bound) + (n, D, B_bound_squared, m_bound) +} + +/// Convert a bound on the infinite norm of a vector into a bound on the square of the euclidean +/// norm. +/// +/// Use the relationship: `||x||_2 <= sqrt(dim)*||x||_inf`. Since we are only interested in the +/// squared bound, we avoid the sqrt by returning dim*(||x||_inf)^2. +fn inf_norm_bound_to_euclidean_squared(B_inf: u64, dim: usize) -> u128 { + let norm_squared = sqr(B_inf) as u128; + + norm_squared * dim as u128 } /// Generates a CRS based on the bound the heuristic provided by the lemma 2 of the paper. pub fn crs_gen_ghl( d: usize, k: usize, - B: u64, + B_inf: u64, q: u64, t: u64, msbs_zero_padding_bit_count: u64, rng: &mut dyn RngCore, ) -> PublicParams { + let bound_type = Bound::GHL; let alpha = G::Zp::rand(rng); - let B = B * (isqrt((d + k) as _) as u64 + 1); - let (n, D, B_r, B_bound, m_bound) = - compute_crs_params(d, k, B, q, t, msbs_zero_padding_bit_count, Bound::GHL); + let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k); + let (n, D, B_bound_squared, _) = + compute_crs_params(d, k, B_squared, t, msbs_zero_padding_bit_count, bound_type); PublicParams { g_lists: GroupElements::::new(n, alpha), D, n, d, k, - B, - B_r, - B_bound, - m_bound, + B_inf, + B_bound_squared, q, t, msbs_zero_padding_bit_count, + bound_type, hash: core::array::from_fn(|_| rng.gen()), hash_R: core::array::from_fn(|_| rng.gen()), hash_t: core::array::from_fn(|_| rng.gen()), @@ -562,29 +579,29 @@ pub fn crs_gen_ghl( pub fn crs_gen_cs( d: usize, k: usize, - B: u64, + B_inf: u64, q: u64, t: u64, msbs_zero_padding_bit_count: u64, rng: &mut dyn RngCore, ) -> PublicParams { + let bound_type = Bound::CS; let alpha = G::Zp::rand(rng); - let B = B * (isqrt((d + k) as _) as u64 + 1); - let (n, D, B_r, B_bound, m_bound) = - compute_crs_params(d, k, B, q, t, msbs_zero_padding_bit_count, Bound::CS); + let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k); + let (n, D, B_bound_squared, _) = + compute_crs_params(d, k, B_squared, t, msbs_zero_padding_bit_count, bound_type); PublicParams { g_lists: GroupElements::::new(n, alpha), D, n, d, k, - B, - B_r, - B_bound, - m_bound, + B_bound_squared, + B_inf, q, t, msbs_zero_padding_bit_count, + bound_type, hash: core::array::from_fn(|_| rng.gen()), hash_R: core::array::from_fn(|_| rng.gen()), hash_t: core::array::from_fn(|_| rng.gen()), @@ -681,13 +698,12 @@ fn prove_impl( n, d, k: k_max, - B, - B_r: _, - B_bound, - m_bound, + B_bound_squared, + B_inf, q, t: t_input, msbs_zero_padding_bit_count, + bound_type, ref hash, ref hash_R, ref hash_t, @@ -712,9 +728,16 @@ fn prove_impl( let decoded_q = decode_q(q); - // Recompute the D for our case if k is smaller than the k max - // formula in Prove_pp: 2. - let D = d + k * effective_cleartext_t.ilog2() as usize; + // Recompute some params for our case if k is smaller than the k max + let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k); + let (_, D, _, m_bound) = compute_crs_params( + d, + k, + B_squared, + t_input, + msbs_zero_padding_bit_count, + bound_type, + ); let e_sqr_norm = e1 .iter() @@ -725,9 +748,8 @@ fn prove_impl( if sanity_check_mode == ProofSanityCheckMode::Panic { assert_pke_proof_preconditions(c1, e1, c2, e2, d, k_max, D, D_max); assert!( - sqr(B as u128) >= e_sqr_norm, - "squared norm of error ({e_sqr_norm}) exceeds threshold ({})", - sqr(B as u128) + B_squared >= e_sqr_norm, + "squared norm of error ({e_sqr_norm}) exceeds threshold ({B_squared})", ); } @@ -761,7 +783,7 @@ fn prove_impl( ) .collect::>(); - let v = four_squares(sqr(B as u128) - e_sqr_norm).map(|v| v as i64); + let v = four_squares(B_squared - e_sqr_norm).map(|v| v as i64); let e1_zp = &*e1 .iter() @@ -813,7 +835,7 @@ fn prove_impl( let x_bytes = &*[ q.to_le_bytes().as_slice(), (d as u64).to_le_bytes().as_slice(), - B.to_le_bytes().as_slice(), + B_squared.to_le_bytes().as_slice(), t_input.to_le_bytes().as_slice(), msbs_zero_padding_bit_count.to_le_bytes().as_slice(), &*a.iter() @@ -892,7 +914,7 @@ fn prove_impl( _ => unreachable!(), }); if sanity_check_mode == ProofSanityCheckMode::Panic { - assert!(acc.unsigned_abs() <= B_bound as u128); + assert!(sqr(acc) as u128 <= B_bound_squared); } acc as i64 }) @@ -1357,7 +1379,7 @@ fn prove_impl( } let mut P_pi = poly_0; if P_pi.len() > n + 1 { - P_pi[n + 1] -= delta_theta * t_theta + delta_l * sqr(G::Zp::from_u64(B)); + P_pi[n + 1] -= delta_theta * t_theta + delta_l * G::Zp::from_u128(B_squared); } let pi = if P_pi.is_empty() { @@ -1807,13 +1829,12 @@ pub fn verify( n, d, k: k_max, - B, - B_r: _, - B_bound: _, - m_bound: m, + B_bound_squared: _, + B_inf, q, t: t_input, msbs_zero_padding_bit_count, + bound_type, ref hash, ref hash_R, ref hash_t, @@ -1843,10 +1864,18 @@ pub fn verify( } let effective_cleartext_t = t_input >> msbs_zero_padding_bit_count; + let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k); + let (_, D, _, m_bound) = compute_crs_params( + d, + k, + B_squared, + t_input, + msbs_zero_padding_bit_count, + bound_type, + ); + + let m = m_bound; - // Recompute the D for our case if k is smaller than the k max - // formula in Prove_pp: 2. - let D = d + k * effective_cleartext_t.ilog2() as usize; if D > D_max { return Err(()); } @@ -1869,7 +1898,7 @@ pub fn verify( let x_bytes = &*[ q.to_le_bytes().as_slice(), (d as u64).to_le_bytes().as_slice(), - B.to_le_bytes().as_slice(), + B_squared.to_le_bytes().as_slice(), t_input.to_le_bytes().as_slice(), msbs_zero_padding_bit_count.to_le_bytes().as_slice(), &*a.iter() @@ -2166,7 +2195,7 @@ pub fn verify( G::G1::projective(g_list[0]), G::G2::projective(g_hat_list[n - 1]), ) - .mul_scalar(delta_theta * t_theta + delta_l * sqr(G::Zp::from_u64(B))); + .mul_scalar(delta_theta * t_theta + delta_l * G::Zp::from_u128(B_squared)); lhs0 + lhs1 + lhs2 - lhs3 - lhs4 - lhs5 - lhs6 }; diff --git a/tfhe-zk-pok/src/serialization.rs b/tfhe-zk-pok/src/serialization.rs index 9699bd0b2..bda1a09f5 100644 --- a/tfhe-zk-pok/src/serialization.rs +++ b/tfhe-zk-pok/src/serialization.rs @@ -17,7 +17,7 @@ use tfhe_versionable::Versionize; use crate::curve_api::{Curve, CurveGroupOps}; use crate::proofs::pke::PublicParams as PKEv1PublicParams; -use crate::proofs::pke_v2::PublicParams as PKEv2PublicParams; +use crate::proofs::pke_v2::{Bound, PublicParams as PKEv2PublicParams}; use crate::proofs::GroupElements; /// Error returned when a conversion from a vec to a fixed size array failed because the vec size is @@ -397,13 +397,12 @@ pub struct SerializablePKEv2PublicParams { pub n: usize, pub d: usize, pub k: usize, - pub B: u64, - pub B_r: u64, - pub B_bound: u64, - pub m_bound: usize, + pub B_bound_squared: u128, + pub B_inf: u64, pub q: u64, pub t: u64, pub msbs_zero_padding_bit_count: u64, + pub bound_type: Bound, // We use Vec since serde does not support fixed size arrays of 256 elements pub(crate) hash: Vec, pub(crate) hash_R: Vec, @@ -428,13 +427,12 @@ where n, d, k, - B, - B_r, - B_bound, - m_bound, + B_bound_squared, + B_inf, q, t, msbs_zero_padding_bit_count, + bound_type, hash, hash_R, hash_t, @@ -452,13 +450,12 @@ where n, d, k, - B, - B_r, - B_bound, - m_bound, + B_bound_squared, + B_inf, q, t, msbs_zero_padding_bit_count, + bound_type, hash: hash.to_vec(), hash_R: hash_R.to_vec(), hash_t: hash_t.to_vec(), @@ -487,13 +484,12 @@ where n, d, k, - B, - B_r, - B_bound, - m_bound, + B_bound_squared, + B_inf, q, t, msbs_zero_padding_bit_count, + bound_type, hash, hash_R, hash_t, @@ -511,13 +507,12 @@ where n, d, k, - B, - B_r, - B_bound, - m_bound, + B_bound_squared, + B_inf, q, t, msbs_zero_padding_bit_count, + bound_type, hash: try_vec_to_array(hash)?, hash_R: try_vec_to_array(hash_R)?, hash_t: try_vec_to_array(hash_t)?,