feat(zk): zk perf improvements

This commit is contained in:
sarah el kazdadi
2024-06-25 15:29:45 +02:00
committed by sarah
parent 818e480dac
commit 19e00c484b
2 changed files with 205 additions and 78 deletions

View File

@@ -13,9 +13,9 @@ description = "tfhe-zk-pok: An implementation of zero-knowledge proofs of encryp
[dependencies]
ark-bls12-381 = { package = "tfhe-ark-bls12-381", version = "0.4.0" }
ark-ec = { package = "tfhe-ark-ec", version = "0.4.2" }
ark-ff = { package = "tfhe-ark-ff", version = "0.4.3" }
ark-poly = { package = "tfhe-ark-poly", version = "0.4.2" }
ark-ec = { package = "tfhe-ark-ec", version = "0.4.2", features = ["parallel"] }
ark-ff = { package = "tfhe-ark-ff", version = "0.4.3", features = ["parallel"] }
ark-poly = { package = "tfhe-ark-poly", version = "0.4.2", features = ["parallel"] }
ark-serialize = { version = "0.4.2" }
rand = "0.8.5"
rayon = "1.8.0"

View File

@@ -242,6 +242,96 @@ mod g2 {
.unwrap(),
}
}
// m is an intermediate variable that's used in both the curve point addition and pairing
// functions. we cache it since it requires a Zp division
// https://hackmd.io/@tazAymRSQCGXTUKkbh1BAg/Sk27liTW9#Math-Formula-for-Point-Addition
pub(crate) fn compute_m(self, other: G2Affine) -> Option<crate::curve_446::Fq2> {
let zero = crate::curve_446::Fq2::ZERO;
// in the context of elliptic curves, the point at infinity is the zero element of the
// group
if self.inner.infinity || other.inner.infinity {
return None;
}
if self == other {
let x = self.inner.x;
let y = self.inner.y;
if y == zero {
None
} else {
let xx = x.square();
Some((xx.double() + xx) / y.double())
}
} else {
let x1 = self.inner.x;
let y1 = self.inner.y;
let x2 = other.inner.x;
let y2 = other.inner.y;
let x_delta = x2 - x1;
let y_delta = y2 - y1;
if x_delta == zero {
None
} else {
Some(y_delta / x_delta)
}
}
}
pub(crate) fn double(self, m: Option<crate::curve_446::Fq2>) -> Self {
// in the context of elliptic curves, the point at infinity is the zero element of the
// group
if self.inner.infinity {
return self;
}
let mut result = self;
let x = self.inner.x;
let y = self.inner.y;
if let Some(m) = m {
let x3 = m.square() - x.double();
let y3 = m * (x - x3) - y;
(result.inner.x, result.inner.y) = (x3, y3);
} else {
result.inner.infinity = true;
}
result
}
pub(crate) fn add_unequal(self, other: G2Affine, m: Option<crate::curve_446::Fq2>) -> Self {
// in the context of elliptic curves, the point at infinity is the zero element of the
// group
if self.inner.infinity {
return other;
}
if other.inner.infinity {
return self;
}
let mut result = self;
let x1 = self.inner.x;
let y1 = self.inner.y;
let x2 = other.inner.x;
if let Some(m) = m {
let x3 = m.square() - x1 - x2;
let y3 = m * (x1 - x3) - y1;
(result.inner.x, result.inner.y) = (x3, y3);
} else {
result.inner.infinity = true;
}
result
}
}
#[derive(
@@ -373,9 +463,9 @@ mod g2 {
}
pub fn double(self) -> Self {
Self {
inner: self.inner.double(),
}
let mut this = self;
this.inner.double_in_place();
this
}
}
@@ -431,51 +521,79 @@ mod g2 {
}
mod gt {
use crate::curve_446::{Fq, Fq12, Fq2};
use super::*;
use ark_ec::bls12::Bls12Config;
use ark_ec::pairing::{MillerLoopOutput, Pairing};
use ark_ff::{CubicExtField, Fp12, Fp2, QuadExtField};
use ark_ff::{CubicExtField, QuadExtField};
type Bls = crate::curve_446::Bls12_446;
type Config = crate::curve_446::Config;
const ONE: Fp2<<Config as Bls12Config>::Fp2Config> = QuadExtField {
c0: MontFp!("1"),
c1: MontFp!("0"),
};
const ZERO: Fp2<<Config as Bls12Config>::Fp2Config> = QuadExtField {
const ZERO: Fq2 = QuadExtField {
c0: MontFp!("0"),
c1: MontFp!("0"),
};
const U1: Fp12<<Config as Bls12Config>::Fp12Config> = QuadExtField {
c0: CubicExtField {
c0: ZERO,
c1: ZERO,
c2: ZERO,
},
c1: CubicExtField {
c0: ONE,
c1: ZERO,
c2: ZERO,
},
// computed by copying the result from
// let two: Fq = MontFp!("2"); println!("{}", two.inverse().unwrap()), which we can't compute in
// a const context;
const TWO_INV: Fq = {
MontFp!("86412351771428577990035638289747981121746346761394949218917418178192828331138736448451251370148591845087981000773214233672031082665302")
};
const U3: Fp12<<Config as Bls12Config>::Fp12Config> = QuadExtField {
c0: CubicExtField {
c0: ZERO,
c1: ZERO,
c2: ZERO,
},
c1: CubicExtField {
c0: ZERO,
c1: ONE,
c2: ZERO,
},
const TWO_INV_MINUS_1: Fq = {
MontFp!("86412351771428577990035638289747981121746346761394949218917418178192828331138736448451251370148591845087981000773214233672031082665301")
};
const fn fp2_to_fp12(
x: Fp2<<Config as Bls12Config>::Fp2Config>,
) -> Fp12<<Config as Bls12Config>::Fp12Config> {
// the only non zero value in inv(U1) and inv(U3), which come from Olivier's equations.
const C: Fq2 = QuadExtField {
c0: TWO_INV,
c1: TWO_INV_MINUS_1,
};
fn fp2_mul_c(x: Fq2) -> Fq2 {
let x0_c0 = x.c0 * C.c0;
let x1_c0 = x.c1 * C.c0;
let x0_c1 = x0_c0 - x.c0;
let x1_c1 = x1_c0 - x.c1;
QuadExtField {
c0: x0_c0 - x1_c1,
c1: x0_c1 + x1_c0,
}
}
fn fp2_mul_u1_inv(x: Fq2) -> Fq12 {
QuadExtField {
c0: CubicExtField {
c0: ZERO,
c1: ZERO,
c2: ZERO,
},
c1: CubicExtField {
c0: ZERO,
c1: ZERO,
c2: fp2_mul_c(x),
},
}
}
fn fp2_mul_u3_inv(x: Fq2) -> Fq12 {
QuadExtField {
c0: CubicExtField {
c0: ZERO,
c1: ZERO,
c2: ZERO,
},
c1: CubicExtField {
c0: ZERO,
c1: fp2_mul_c(x),
c2: ZERO,
},
}
}
const fn fp2_to_fp12(x: Fq2) -> Fq12 {
QuadExtField {
c0: CubicExtField {
c0: x,
@@ -490,52 +608,59 @@ mod gt {
}
}
const fn fp_to_fp12(
x: <Config as Bls12Config>::Fp,
) -> Fp12<<Config as Bls12Config>::Fp12Config> {
fp2_to_fp12(QuadExtField {
const fn fp_to_fp2(x: Fq) -> Fq2 {
QuadExtField {
c0: x,
c1: MontFp!("0"),
})
}
}
fn ate_tangent_ev(qt: G2, evpt: G1) -> Fp12<<Config as Bls12Config>::Fp12Config> {
let qt = qt.inner.into_affine();
let evpt = evpt.inner.into_affine();
const fn fp_to_fp12(x: Fq) -> Fq12 {
fp2_to_fp12(fp_to_fp2(x))
}
fn ate_tangent_ev(qt: G2Affine, evpt: G1Affine, m: Fq2) -> Fq12 {
let qt = qt.inner;
let evpt = evpt.inner;
let (xt, yt) = (qt.x, qt.y);
let (xe, ye) = (evpt.x, evpt.y);
let xt = fp2_to_fp12(xt);
let yt = fp2_to_fp12(yt);
let xe = fp_to_fp12(xe);
let ye = fp_to_fp12(ye);
let l = m;
let mut l_xe = l;
l_xe.c0 *= xe;
l_xe.c1 *= xe;
let three = fp_to_fp12(MontFp!("3"));
let two = fp_to_fp12(MontFp!("2"));
let mut r0 = fp_to_fp12(ye);
let r1 = fp2_mul_u1_inv(l_xe);
let r2 = fp2_mul_u3_inv(l * xt - yt);
let l = three * xt.square() / (two * yt);
ye - (l * xe) / U1 + (l * xt - yt) / U3
r0.c1.c1 = r2.c1.c1;
r0.c1.c2 = -r1.c1.c2;
r0
}
fn ate_line_ev(q1: G2, q2: G2, evpt: G1) -> Fp12<<Config as Bls12Config>::Fp12Config> {
let q1 = q1.inner.into_affine();
let q2 = q2.inner.into_affine();
let evpt = evpt.inner.into_affine();
fn ate_line_ev(q1: G2Affine, evpt: G1Affine, m: Fq2) -> Fq12 {
let q1 = q1.inner;
let evpt = evpt.inner;
let (x1, y1) = (q1.x, q1.y);
let (x2, y2) = (q2.x, q2.y);
let (xe, ye) = (evpt.x, evpt.y);
let x1 = fp2_to_fp12(x1);
let y1 = fp2_to_fp12(y1);
let x2 = fp2_to_fp12(x2);
let y2 = fp2_to_fp12(y2);
let xe = fp_to_fp12(xe);
let ye = fp_to_fp12(ye);
let l = m;
let mut l_xe = l;
l_xe.c0 *= xe;
l_xe.c1 *= xe;
let l = (y2 - y1) / (x2 - x1);
ye - (l * xe) / U1 + (l * x1 - y1) / U3
let mut r0 = fp_to_fp12(ye);
let r1 = fp2_mul_u1_inv(l * fp_to_fp2(xe));
let r2 = fp2_mul_u3_inv(l * x1 - y1);
r0.c1.c1 = r2.c1.c1;
r0.c1.c2 = -r1.c1.c2;
r0
}
#[allow(clippy::needless_range_loop)]
@@ -544,22 +669,24 @@ mod gt {
let t_bits = b"110000000001000001000000100000000000000000000000000000000100000000000000001";
let mut fk = fp_to_fp12(MontFp!("1"));
let p = p.normalize();
let q = q.normalize();
let mut qk = q;
for k in 1..t_log2 {
let lkk = ate_tangent_ev(qk, p);
qk = qk + qk;
let m = qk.compute_m(qk).unwrap();
let lkk = ate_tangent_ev(qk, p, m);
qk = qk.double(Some(m));
fk = fk.square() * lkk;
if t_bits[k] == b'1' {
assert_ne!(q, qk);
let lkp1 = if q != -qk {
ate_line_ev(q, qk, p)
} else {
fp_to_fp12(MontFp!("1"))
};
qk += q;
fk *= lkp1;
let m = q.compute_m(qk);
let new_qk = q.add_unequal(qk, m);
if !new_qk.inner.infinity {
fk *= ate_line_ev(q, p, m.unwrap());
}
qk = new_qk;
}
}
let mlo = MillerLoopOutput(fk);