fix(zk): recompute B according to k in proof and use squared bounds

This removes the need for sqrt operations
also fix a proof slack was too big in v2
This commit is contained in:
Nicolas Sarlin
2024-11-18 17:27:57 +01:00
committed by Nicolas Sarlin
parent 770ae22bb6
commit 87dbfdcd5e
3 changed files with 202 additions and 108 deletions

View File

@@ -1,12 +1,18 @@
// to follow the notation of the paper
#![allow(non_snake_case)]
pub mod pke;
pub mod pke_v2;
use std::convert::Infallible;
use std::error::Error;
use std::fmt::Display;
use tfhe_versionable::VersionsDispatch;
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
use crate::curve_api::Curve;
use crate::four_squares::{isqrt, sqr};
use crate::proofs::pke_v2::Bound;
use crate::proofs::GroupElements;
use crate::serialization::{
SerializableAffine, SerializableCubicExtField, SerializableFp, SerializableFp2,
@@ -65,6 +71,65 @@ pub(crate) enum SerializableGroupElementsVersions {
V0(SerializableGroupElements),
}
#[derive(Version)]
pub struct SerializablePKEv2PublicParamsV0 {
pub(crate) g_lists: SerializableGroupElements,
pub(crate) D: usize,
pub n: usize,
pub d: usize,
pub k: usize,
pub B: u64,
pub B_r: u64,
pub B_bound: u64,
pub m_bound: usize,
pub q: u64,
pub t: u64,
pub msbs_zero_padding_bit_count: u64,
// We use Vec<u8> since serde does not support fixed size arrays of 256 elements
pub(crate) hash: Vec<u8>,
pub(crate) hash_R: Vec<u8>,
pub(crate) hash_t: Vec<u8>,
pub(crate) hash_w: Vec<u8>,
pub(crate) hash_agg: Vec<u8>,
pub(crate) hash_lmap: Vec<u8>,
pub(crate) hash_phi: Vec<u8>,
pub(crate) hash_xi: Vec<u8>,
pub(crate) hash_z: Vec<u8>,
pub(crate) hash_chi: Vec<u8>,
}
impl Upgrade<SerializablePKEv2PublicParams> for SerializablePKEv2PublicParamsV0 {
type Error = Infallible;
fn upgrade(self) -> Result<SerializablePKEv2PublicParams, Self::Error> {
let slack_factor = isqrt((self.d + self.k) as u128) as u64;
let B_inf = self.B / slack_factor;
Ok(SerializablePKEv2PublicParams {
g_lists: self.g_lists,
D: self.D,
n: self.n,
d: self.d,
k: self.k,
B_bound_squared: sqr(self.B_bound as u128),
B_inf,
q: self.q,
t: self.t,
msbs_zero_padding_bit_count: self.msbs_zero_padding_bit_count,
bound_type: Bound::CS,
hash: self.hash,
hash_R: self.hash_R,
hash_t: self.hash_t,
hash_w: self.hash_w,
hash_agg: self.hash_agg,
hash_lmap: self.hash_lmap,
hash_phi: self.hash_phi,
hash_xi: self.hash_xi,
hash_z: self.hash_z,
hash_chi: self.hash_chi,
})
}
}
#[derive(VersionsDispatch)]
pub enum SerializablePKEv2PublicParamsVersions {
V0(SerializablePKEv2PublicParams),
@@ -74,3 +139,8 @@ pub enum SerializablePKEv2PublicParamsVersions {
pub enum SerializablePKEv1PublicParamsVersions {
V0(SerializablePKEv1PublicParams),
}
#[derive(VersionsDispatch)]
pub enum BoundVersions {
V0(Bound),
}

View File

@@ -3,6 +3,7 @@
use super::*;
use crate::backward_compatibility::pke_v2::{CompressedProofVersions, ProofVersions};
use crate::backward_compatibility::BoundVersions;
use crate::curve_api::{CompressedG1, CompressedG2};
use crate::four_squares::*;
use crate::serialization::{
@@ -35,13 +36,13 @@ pub struct PublicParams<G: Curve> {
pub n: usize,
pub d: usize,
pub k: usize,
pub B: u64,
pub B_r: u64,
pub B_bound: u64,
pub m_bound: usize,
// We store the square of the bound to avoid rounding on sqrt operations
pub B_bound_squared: u128,
pub B_inf: u64,
pub q: u64,
pub t: u64,
pub msbs_zero_padding_bit_count: u64,
pub bound_type: Bound,
pub(crate) hash: [u8; HASH_METADATA_LEN_BYTES],
pub(crate) hash_R: [u8; HASH_METADATA_LEN_BYTES],
pub(crate) hash_t: [u8; HASH_METADATA_LEN_BYTES],
@@ -72,13 +73,12 @@ where
n,
d,
k,
B,
B_r,
B_bound,
m_bound,
B_bound_squared,
B_inf,
q,
t,
msbs_zero_padding_bit_count,
bound_type,
hash,
hash_R,
hash_t,
@@ -96,13 +96,12 @@ where
n: *n,
d: *d,
k: *k,
B: *B,
B_r: *B_r,
B_bound: *B_bound,
m_bound: *m_bound,
B_inf: *B_inf,
B_bound_squared: *B_bound_squared,
q: *q,
t: *t,
msbs_zero_padding_bit_count: *msbs_zero_padding_bit_count,
bound_type: *bound_type,
hash: hash.to_vec(),
hash_R: hash_R.to_vec(),
hash_t: hash_t.to_vec(),
@@ -123,13 +122,12 @@ where
n,
d,
k,
B,
B_r,
B_bound,
m_bound,
B_bound_squared,
B_inf,
q,
t,
msbs_zero_padding_bit_count,
bound_type,
hash,
hash_R,
hash_t,
@@ -147,13 +145,12 @@ where
n,
d,
k,
B,
B_r,
B_bound,
m_bound,
B_bound_squared,
B_inf,
q,
t,
msbs_zero_padding_bit_count,
bound_type,
hash: try_vec_to_array(hash)?,
hash_R: try_vec_to_array(hash_R)?,
hash_t: try_vec_to_array(hash_t)?,
@@ -175,11 +172,11 @@ impl<G: Curve> PublicParams<G> {
g_hat_list: Vec<Affine<G::Zp, G::G2>>,
d: usize,
k: usize,
B: u64,
B_inf: u64,
q: u64,
t: u64,
msbs_zero_padding_bit_count: u64,
bound: Bound,
bound_type: Bound,
hash: [u8; HASH_METADATA_LEN_BYTES],
hash_R: [u8; HASH_METADATA_LEN_BYTES],
hash_t: [u8; HASH_METADATA_LEN_BYTES],
@@ -191,21 +188,21 @@ impl<G: Curve> PublicParams<G> {
hash_z: [u8; HASH_METADATA_LEN_BYTES],
hash_chi: [u8; HASH_METADATA_LEN_BYTES],
) -> Self {
let (n, D, B_r, B_bound, m_bound) =
compute_crs_params(d, k, B, q, t, msbs_zero_padding_bit_count, bound);
let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k);
let (n, D, B_bound_squared, _) =
compute_crs_params(d, k, B_squared, t, msbs_zero_padding_bit_count, bound_type);
Self {
g_lists: GroupElements::<G>::from_vec(g_list, g_hat_list),
D,
n,
d,
k,
B,
B_r,
B_bound,
m_bound,
B_bound_squared,
B_inf,
q,
t,
msbs_zero_padding_bit_count,
bound_type,
hash,
hash_R,
hash_t,
@@ -220,7 +217,9 @@ impl<G: Curve> PublicParams<G> {
}
pub fn exclusive_max_noise(&self) -> u64 {
self.B
// Here we return the bound without slack because users aren't supposed to generate noise
// inside the slack
self.B_inf + 1
}
/// Check if the crs can be used to generate or verify a proof
@@ -478,72 +477,90 @@ pub struct PrivateCommit<G: Curve> {
__marker: PhantomData<G>,
}
#[derive(Copy, Clone, Debug)]
#[derive(PartialEq, Copy, Clone, Debug, Serialize, Deserialize, Versionize)]
#[versionize(BoundVersions)]
pub enum Bound {
GHL,
CS,
}
fn ceil_ilog2(value: u128) -> u64 {
value.ilog2() as u64 + if value.is_power_of_two() { 0 } else { 1 }
}
pub fn compute_crs_params(
d: usize,
k: usize,
B: u64,
_q: u64, // we keep q here to make sure the API is consistent with [crs_gen]
B_squared: u128,
t: u64,
msbs_zero_padding_bit_count: u64,
bound: Bound,
) -> (usize, usize, u64, u64, usize) {
let B_r = d as u64 / 2 + 1;
let B_bound = {
let B = B as f64;
let d = d as f64;
let k = k as f64;
bound_type: Bound,
) -> (usize, usize, u128, usize) {
let mut B_bound_squared = {
(match bound_type {
// GHL factor is 9.75, 9.75**2 = 95.0625
// Result is multiplied and divided by 10000 to avoid floating point operations
Bound::GHL => 950625,
Bound::CS => (2 * (d + k) + 4) as u128,
}) * (B_squared + (sqr(d + 2) * (d + k)) as u128 / 4)
};
(match bound {
Bound::GHL => 9.75,
Bound::CS => f64::sqrt(2.0 * (d + k) + 4.0),
}) * f64::sqrt(sqr(B) + (sqr(d + 2.0) * (d + k)) / 4.0)
if bound_type == Bound::GHL {
B_bound_squared /= 10000;
}
.ceil() as u64;
// Formula is round_up(1 + B_bound.ilog2()) so we convert it to +2
let m_bound = 2 + B_bound.ilog2() as usize;
// Formula is round_up(1 + B_bound.ilog2()).
// Since we use B_bound_square, the log is divided by 2
let m_bound = 1 + ceil_ilog2(B_bound_squared).div_ceil(2) as usize;
// This is also the effective t for encryption
let effective_t_for_decomposition = t >> msbs_zero_padding_bit_count;
// formula in Prove_pp: 2.
let D = d + k * effective_t_for_decomposition.ilog2() as usize;
let n = D + 128 * m_bound;
(n, D, B_r, B_bound, m_bound)
(n, D, B_bound_squared, m_bound)
}
/// Convert a bound on the infinite norm of a vector into a bound on the square of the euclidean
/// norm.
///
/// Use the relationship: `||x||_2 <= sqrt(dim)*||x||_inf`. Since we are only interested in the
/// squared bound, we avoid the sqrt by returning dim*(||x||_inf)^2.
fn inf_norm_bound_to_euclidean_squared(B_inf: u64, dim: usize) -> u128 {
let norm_squared = sqr(B_inf) as u128;
norm_squared * dim as u128
}
/// Generates a CRS based on the bound the heuristic provided by the lemma 2 of the paper.
pub fn crs_gen_ghl<G: Curve>(
d: usize,
k: usize,
B: u64,
B_inf: u64,
q: u64,
t: u64,
msbs_zero_padding_bit_count: u64,
rng: &mut dyn RngCore,
) -> PublicParams<G> {
let bound_type = Bound::GHL;
let alpha = G::Zp::rand(rng);
let B = B * (isqrt((d + k) as _) as u64 + 1);
let (n, D, B_r, B_bound, m_bound) =
compute_crs_params(d, k, B, q, t, msbs_zero_padding_bit_count, Bound::GHL);
let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k);
let (n, D, B_bound_squared, _) =
compute_crs_params(d, k, B_squared, t, msbs_zero_padding_bit_count, bound_type);
PublicParams {
g_lists: GroupElements::<G>::new(n, alpha),
D,
n,
d,
k,
B,
B_r,
B_bound,
m_bound,
B_inf,
B_bound_squared,
q,
t,
msbs_zero_padding_bit_count,
bound_type,
hash: core::array::from_fn(|_| rng.gen()),
hash_R: core::array::from_fn(|_| rng.gen()),
hash_t: core::array::from_fn(|_| rng.gen()),
@@ -562,29 +579,29 @@ pub fn crs_gen_ghl<G: Curve>(
pub fn crs_gen_cs<G: Curve>(
d: usize,
k: usize,
B: u64,
B_inf: u64,
q: u64,
t: u64,
msbs_zero_padding_bit_count: u64,
rng: &mut dyn RngCore,
) -> PublicParams<G> {
let bound_type = Bound::CS;
let alpha = G::Zp::rand(rng);
let B = B * (isqrt((d + k) as _) as u64 + 1);
let (n, D, B_r, B_bound, m_bound) =
compute_crs_params(d, k, B, q, t, msbs_zero_padding_bit_count, Bound::CS);
let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k);
let (n, D, B_bound_squared, _) =
compute_crs_params(d, k, B_squared, t, msbs_zero_padding_bit_count, bound_type);
PublicParams {
g_lists: GroupElements::<G>::new(n, alpha),
D,
n,
d,
k,
B,
B_r,
B_bound,
m_bound,
B_bound_squared,
B_inf,
q,
t,
msbs_zero_padding_bit_count,
bound_type,
hash: core::array::from_fn(|_| rng.gen()),
hash_R: core::array::from_fn(|_| rng.gen()),
hash_t: core::array::from_fn(|_| rng.gen()),
@@ -681,13 +698,12 @@ fn prove_impl<G: Curve>(
n,
d,
k: k_max,
B,
B_r: _,
B_bound,
m_bound,
B_bound_squared,
B_inf,
q,
t: t_input,
msbs_zero_padding_bit_count,
bound_type,
ref hash,
ref hash_R,
ref hash_t,
@@ -712,9 +728,16 @@ fn prove_impl<G: Curve>(
let decoded_q = decode_q(q);
// Recompute the D for our case if k is smaller than the k max
// formula in Prove_pp: 2.
let D = d + k * effective_cleartext_t.ilog2() as usize;
// Recompute some params for our case if k is smaller than the k max
let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k);
let (_, D, _, m_bound) = compute_crs_params(
d,
k,
B_squared,
t_input,
msbs_zero_padding_bit_count,
bound_type,
);
let e_sqr_norm = e1
.iter()
@@ -725,9 +748,8 @@ fn prove_impl<G: Curve>(
if sanity_check_mode == ProofSanityCheckMode::Panic {
assert_pke_proof_preconditions(c1, e1, c2, e2, d, k_max, D, D_max);
assert!(
sqr(B as u128) >= e_sqr_norm,
"squared norm of error ({e_sqr_norm}) exceeds threshold ({})",
sqr(B as u128)
B_squared >= e_sqr_norm,
"squared norm of error ({e_sqr_norm}) exceeds threshold ({B_squared})",
);
}
@@ -761,7 +783,7 @@ fn prove_impl<G: Curve>(
)
.collect::<Box<[_]>>();
let v = four_squares(sqr(B as u128) - e_sqr_norm).map(|v| v as i64);
let v = four_squares(B_squared - e_sqr_norm).map(|v| v as i64);
let e1_zp = &*e1
.iter()
@@ -813,7 +835,7 @@ fn prove_impl<G: Curve>(
let x_bytes = &*[
q.to_le_bytes().as_slice(),
(d as u64).to_le_bytes().as_slice(),
B.to_le_bytes().as_slice(),
B_squared.to_le_bytes().as_slice(),
t_input.to_le_bytes().as_slice(),
msbs_zero_padding_bit_count.to_le_bytes().as_slice(),
&*a.iter()
@@ -892,7 +914,7 @@ fn prove_impl<G: Curve>(
_ => unreachable!(),
});
if sanity_check_mode == ProofSanityCheckMode::Panic {
assert!(acc.unsigned_abs() <= B_bound as u128);
assert!(sqr(acc) as u128 <= B_bound_squared);
}
acc as i64
})
@@ -1357,7 +1379,7 @@ fn prove_impl<G: Curve>(
}
let mut P_pi = poly_0;
if P_pi.len() > n + 1 {
P_pi[n + 1] -= delta_theta * t_theta + delta_l * sqr(G::Zp::from_u64(B));
P_pi[n + 1] -= delta_theta * t_theta + delta_l * G::Zp::from_u128(B_squared);
}
let pi = if P_pi.is_empty() {
@@ -1807,13 +1829,12 @@ pub fn verify<G: Curve>(
n,
d,
k: k_max,
B,
B_r: _,
B_bound: _,
m_bound: m,
B_bound_squared: _,
B_inf,
q,
t: t_input,
msbs_zero_padding_bit_count,
bound_type,
ref hash,
ref hash_R,
ref hash_t,
@@ -1843,10 +1864,18 @@ pub fn verify<G: Curve>(
}
let effective_cleartext_t = t_input >> msbs_zero_padding_bit_count;
let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k);
let (_, D, _, m_bound) = compute_crs_params(
d,
k,
B_squared,
t_input,
msbs_zero_padding_bit_count,
bound_type,
);
let m = m_bound;
// Recompute the D for our case if k is smaller than the k max
// formula in Prove_pp: 2.
let D = d + k * effective_cleartext_t.ilog2() as usize;
if D > D_max {
return Err(());
}
@@ -1869,7 +1898,7 @@ pub fn verify<G: Curve>(
let x_bytes = &*[
q.to_le_bytes().as_slice(),
(d as u64).to_le_bytes().as_slice(),
B.to_le_bytes().as_slice(),
B_squared.to_le_bytes().as_slice(),
t_input.to_le_bytes().as_slice(),
msbs_zero_padding_bit_count.to_le_bytes().as_slice(),
&*a.iter()
@@ -2166,7 +2195,7 @@ pub fn verify<G: Curve>(
G::G1::projective(g_list[0]),
G::G2::projective(g_hat_list[n - 1]),
)
.mul_scalar(delta_theta * t_theta + delta_l * sqr(G::Zp::from_u64(B)));
.mul_scalar(delta_theta * t_theta + delta_l * G::Zp::from_u128(B_squared));
lhs0 + lhs1 + lhs2 - lhs3 - lhs4 - lhs5 - lhs6
};

View File

@@ -17,7 +17,7 @@ use tfhe_versionable::Versionize;
use crate::curve_api::{Curve, CurveGroupOps};
use crate::proofs::pke::PublicParams as PKEv1PublicParams;
use crate::proofs::pke_v2::PublicParams as PKEv2PublicParams;
use crate::proofs::pke_v2::{Bound, PublicParams as PKEv2PublicParams};
use crate::proofs::GroupElements;
/// Error returned when a conversion from a vec to a fixed size array failed because the vec size is
@@ -397,13 +397,12 @@ pub struct SerializablePKEv2PublicParams {
pub n: usize,
pub d: usize,
pub k: usize,
pub B: u64,
pub B_r: u64,
pub B_bound: u64,
pub m_bound: usize,
pub B_bound_squared: u128,
pub B_inf: u64,
pub q: u64,
pub t: u64,
pub msbs_zero_padding_bit_count: u64,
pub bound_type: Bound,
// We use Vec<u8> since serde does not support fixed size arrays of 256 elements
pub(crate) hash: Vec<u8>,
pub(crate) hash_R: Vec<u8>,
@@ -428,13 +427,12 @@ where
n,
d,
k,
B,
B_r,
B_bound,
m_bound,
B_bound_squared,
B_inf,
q,
t,
msbs_zero_padding_bit_count,
bound_type,
hash,
hash_R,
hash_t,
@@ -452,13 +450,12 @@ where
n,
d,
k,
B,
B_r,
B_bound,
m_bound,
B_bound_squared,
B_inf,
q,
t,
msbs_zero_padding_bit_count,
bound_type,
hash: hash.to_vec(),
hash_R: hash_R.to_vec(),
hash_t: hash_t.to_vec(),
@@ -487,13 +484,12 @@ where
n,
d,
k,
B,
B_r,
B_bound,
m_bound,
B_bound_squared,
B_inf,
q,
t,
msbs_zero_padding_bit_count,
bound_type,
hash,
hash_R,
hash_t,
@@ -511,13 +507,12 @@ where
n,
d,
k,
B,
B_r,
B_bound,
m_bound,
B_bound_squared,
B_inf,
q,
t,
msbs_zero_padding_bit_count,
bound_type,
hash: try_vec_to_array(hash)?,
hash_R: try_vec_to_array(hash_R)?,
hash_t: try_vec_to_array(hash_t)?,