feat(zk): zk perf improvements

This commit is contained in:
sarah el kazdadi
2024-06-25 15:29:45 +02:00
committed by sarah
parent 818e480dac
commit 19e00c484b
2 changed files with 205 additions and 78 deletions

View File

@@ -13,9 +13,9 @@ description = "tfhe-zk-pok: An implementation of zero-knowledge proofs of encryp
[dependencies] [dependencies]
ark-bls12-381 = { package = "tfhe-ark-bls12-381", version = "0.4.0" } ark-bls12-381 = { package = "tfhe-ark-bls12-381", version = "0.4.0" }
ark-ec = { package = "tfhe-ark-ec", version = "0.4.2" } ark-ec = { package = "tfhe-ark-ec", version = "0.4.2", features = ["parallel"] }
ark-ff = { package = "tfhe-ark-ff", version = "0.4.3" } ark-ff = { package = "tfhe-ark-ff", version = "0.4.3", features = ["parallel"] }
ark-poly = { package = "tfhe-ark-poly", version = "0.4.2" } ark-poly = { package = "tfhe-ark-poly", version = "0.4.2", features = ["parallel"] }
ark-serialize = { version = "0.4.2" } ark-serialize = { version = "0.4.2" }
rand = "0.8.5" rand = "0.8.5"
rayon = "1.8.0" rayon = "1.8.0"

View File

@@ -242,6 +242,96 @@ mod g2 {
.unwrap(), .unwrap(),
} }
} }
// m is an intermediate variable that's used in both the curve point addition and pairing
// functions. we cache it since it requires a Zp division
// https://hackmd.io/@tazAymRSQCGXTUKkbh1BAg/Sk27liTW9#Math-Formula-for-Point-Addition
pub(crate) fn compute_m(self, other: G2Affine) -> Option<crate::curve_446::Fq2> {
let zero = crate::curve_446::Fq2::ZERO;
// in the context of elliptic curves, the point at infinity is the zero element of the
// group
if self.inner.infinity || other.inner.infinity {
return None;
}
if self == other {
let x = self.inner.x;
let y = self.inner.y;
if y == zero {
None
} else {
let xx = x.square();
Some((xx.double() + xx) / y.double())
}
} else {
let x1 = self.inner.x;
let y1 = self.inner.y;
let x2 = other.inner.x;
let y2 = other.inner.y;
let x_delta = x2 - x1;
let y_delta = y2 - y1;
if x_delta == zero {
None
} else {
Some(y_delta / x_delta)
}
}
}
pub(crate) fn double(self, m: Option<crate::curve_446::Fq2>) -> Self {
// in the context of elliptic curves, the point at infinity is the zero element of the
// group
if self.inner.infinity {
return self;
}
let mut result = self;
let x = self.inner.x;
let y = self.inner.y;
if let Some(m) = m {
let x3 = m.square() - x.double();
let y3 = m * (x - x3) - y;
(result.inner.x, result.inner.y) = (x3, y3);
} else {
result.inner.infinity = true;
}
result
}
pub(crate) fn add_unequal(self, other: G2Affine, m: Option<crate::curve_446::Fq2>) -> Self {
// in the context of elliptic curves, the point at infinity is the zero element of the
// group
if self.inner.infinity {
return other;
}
if other.inner.infinity {
return self;
}
let mut result = self;
let x1 = self.inner.x;
let y1 = self.inner.y;
let x2 = other.inner.x;
if let Some(m) = m {
let x3 = m.square() - x1 - x2;
let y3 = m * (x1 - x3) - y1;
(result.inner.x, result.inner.y) = (x3, y3);
} else {
result.inner.infinity = true;
}
result
}
} }
#[derive( #[derive(
@@ -373,9 +463,9 @@ mod g2 {
} }
pub fn double(self) -> Self { pub fn double(self) -> Self {
Self { let mut this = self;
inner: self.inner.double(), this.inner.double_in_place();
} this
} }
} }
@@ -431,51 +521,79 @@ mod g2 {
} }
mod gt { mod gt {
use crate::curve_446::{Fq, Fq12, Fq2};
use super::*; use super::*;
use ark_ec::bls12::Bls12Config;
use ark_ec::pairing::{MillerLoopOutput, Pairing}; use ark_ec::pairing::{MillerLoopOutput, Pairing};
use ark_ff::{CubicExtField, Fp12, Fp2, QuadExtField}; use ark_ff::{CubicExtField, QuadExtField};
type Bls = crate::curve_446::Bls12_446; type Bls = crate::curve_446::Bls12_446;
type Config = crate::curve_446::Config;
const ONE: Fp2<<Config as Bls12Config>::Fp2Config> = QuadExtField { const ZERO: Fq2 = QuadExtField {
c0: MontFp!("1"),
c1: MontFp!("0"),
};
const ZERO: Fp2<<Config as Bls12Config>::Fp2Config> = QuadExtField {
c0: MontFp!("0"), c0: MontFp!("0"),
c1: MontFp!("0"), c1: MontFp!("0"),
}; };
const U1: Fp12<<Config as Bls12Config>::Fp12Config> = QuadExtField { // computed by copying the result from
c0: CubicExtField { // let two: Fq = MontFp!("2"); println!("{}", two.inverse().unwrap()), which we can't compute in
c0: ZERO, // a const context;
c1: ZERO, const TWO_INV: Fq = {
c2: ZERO, MontFp!("86412351771428577990035638289747981121746346761394949218917418178192828331138736448451251370148591845087981000773214233672031082665302")
},
c1: CubicExtField {
c0: ONE,
c1: ZERO,
c2: ZERO,
},
}; };
const U3: Fp12<<Config as Bls12Config>::Fp12Config> = QuadExtField { const TWO_INV_MINUS_1: Fq = {
c0: CubicExtField { MontFp!("86412351771428577990035638289747981121746346761394949218917418178192828331138736448451251370148591845087981000773214233672031082665301")
c0: ZERO,
c1: ZERO,
c2: ZERO,
},
c1: CubicExtField {
c0: ZERO,
c1: ONE,
c2: ZERO,
},
}; };
const fn fp2_to_fp12( // the only non zero value in inv(U1) and inv(U3), which come from Olivier's equations.
x: Fp2<<Config as Bls12Config>::Fp2Config>, const C: Fq2 = QuadExtField {
) -> Fp12<<Config as Bls12Config>::Fp12Config> { c0: TWO_INV,
c1: TWO_INV_MINUS_1,
};
fn fp2_mul_c(x: Fq2) -> Fq2 {
let x0_c0 = x.c0 * C.c0;
let x1_c0 = x.c1 * C.c0;
let x0_c1 = x0_c0 - x.c0;
let x1_c1 = x1_c0 - x.c1;
QuadExtField {
c0: x0_c0 - x1_c1,
c1: x0_c1 + x1_c0,
}
}
fn fp2_mul_u1_inv(x: Fq2) -> Fq12 {
QuadExtField {
c0: CubicExtField {
c0: ZERO,
c1: ZERO,
c2: ZERO,
},
c1: CubicExtField {
c0: ZERO,
c1: ZERO,
c2: fp2_mul_c(x),
},
}
}
fn fp2_mul_u3_inv(x: Fq2) -> Fq12 {
QuadExtField {
c0: CubicExtField {
c0: ZERO,
c1: ZERO,
c2: ZERO,
},
c1: CubicExtField {
c0: ZERO,
c1: fp2_mul_c(x),
c2: ZERO,
},
}
}
const fn fp2_to_fp12(x: Fq2) -> Fq12 {
QuadExtField { QuadExtField {
c0: CubicExtField { c0: CubicExtField {
c0: x, c0: x,
@@ -490,52 +608,59 @@ mod gt {
} }
} }
const fn fp_to_fp12( const fn fp_to_fp2(x: Fq) -> Fq2 {
x: <Config as Bls12Config>::Fp, QuadExtField {
) -> Fp12<<Config as Bls12Config>::Fp12Config> {
fp2_to_fp12(QuadExtField {
c0: x, c0: x,
c1: MontFp!("0"), c1: MontFp!("0"),
}) }
} }
fn ate_tangent_ev(qt: G2, evpt: G1) -> Fp12<<Config as Bls12Config>::Fp12Config> { const fn fp_to_fp12(x: Fq) -> Fq12 {
let qt = qt.inner.into_affine(); fp2_to_fp12(fp_to_fp2(x))
let evpt = evpt.inner.into_affine(); }
fn ate_tangent_ev(qt: G2Affine, evpt: G1Affine, m: Fq2) -> Fq12 {
let qt = qt.inner;
let evpt = evpt.inner;
let (xt, yt) = (qt.x, qt.y); let (xt, yt) = (qt.x, qt.y);
let (xe, ye) = (evpt.x, evpt.y); let (xe, ye) = (evpt.x, evpt.y);
let xt = fp2_to_fp12(xt); let l = m;
let yt = fp2_to_fp12(yt); let mut l_xe = l;
let xe = fp_to_fp12(xe); l_xe.c0 *= xe;
let ye = fp_to_fp12(ye); l_xe.c1 *= xe;
let three = fp_to_fp12(MontFp!("3")); let mut r0 = fp_to_fp12(ye);
let two = fp_to_fp12(MontFp!("2")); let r1 = fp2_mul_u1_inv(l_xe);
let r2 = fp2_mul_u3_inv(l * xt - yt);
let l = three * xt.square() / (two * yt); r0.c1.c1 = r2.c1.c1;
ye - (l * xe) / U1 + (l * xt - yt) / U3 r0.c1.c2 = -r1.c1.c2;
r0
} }
fn ate_line_ev(q1: G2, q2: G2, evpt: G1) -> Fp12<<Config as Bls12Config>::Fp12Config> { fn ate_line_ev(q1: G2Affine, evpt: G1Affine, m: Fq2) -> Fq12 {
let q1 = q1.inner.into_affine(); let q1 = q1.inner;
let q2 = q2.inner.into_affine(); let evpt = evpt.inner;
let evpt = evpt.inner.into_affine();
let (x1, y1) = (q1.x, q1.y); let (x1, y1) = (q1.x, q1.y);
let (x2, y2) = (q2.x, q2.y);
let (xe, ye) = (evpt.x, evpt.y); let (xe, ye) = (evpt.x, evpt.y);
let x1 = fp2_to_fp12(x1); let l = m;
let y1 = fp2_to_fp12(y1); let mut l_xe = l;
let x2 = fp2_to_fp12(x2); l_xe.c0 *= xe;
let y2 = fp2_to_fp12(y2); l_xe.c1 *= xe;
let xe = fp_to_fp12(xe);
let ye = fp_to_fp12(ye);
let l = (y2 - y1) / (x2 - x1); let mut r0 = fp_to_fp12(ye);
ye - (l * xe) / U1 + (l * x1 - y1) / U3 let r1 = fp2_mul_u1_inv(l * fp_to_fp2(xe));
let r2 = fp2_mul_u3_inv(l * x1 - y1);
r0.c1.c1 = r2.c1.c1;
r0.c1.c2 = -r1.c1.c2;
r0
} }
#[allow(clippy::needless_range_loop)] #[allow(clippy::needless_range_loop)]
@@ -544,22 +669,24 @@ mod gt {
let t_bits = b"110000000001000001000000100000000000000000000000000000000100000000000000001"; let t_bits = b"110000000001000001000000100000000000000000000000000000000100000000000000001";
let mut fk = fp_to_fp12(MontFp!("1")); let mut fk = fp_to_fp12(MontFp!("1"));
let p = p.normalize();
let q = q.normalize();
let mut qk = q; let mut qk = q;
for k in 1..t_log2 { for k in 1..t_log2 {
let lkk = ate_tangent_ev(qk, p); let m = qk.compute_m(qk).unwrap();
qk = qk + qk; let lkk = ate_tangent_ev(qk, p, m);
qk = qk.double(Some(m));
fk = fk.square() * lkk; fk = fk.square() * lkk;
if t_bits[k] == b'1' { if t_bits[k] == b'1' {
assert_ne!(q, qk); let m = q.compute_m(qk);
let lkp1 = if q != -qk { let new_qk = q.add_unequal(qk, m);
ate_line_ev(q, qk, p) if !new_qk.inner.infinity {
} else { fk *= ate_line_ev(q, p, m.unwrap());
fp_to_fp12(MontFp!("1")) }
}; qk = new_qk;
qk += q;
fk *= lkp1;
} }
} }
let mlo = MillerLoopOutput(fk); let mlo = MillerLoopOutput(fk);