From 19e00c484bdcf56a316dc301b1d6486f3ab09287 Mon Sep 17 00:00:00 2001 From: sarah el kazdadi Date: Tue, 25 Jun 2024 15:29:45 +0200 Subject: [PATCH] feat(zk): zk perf improvements --- tfhe-zk-pok/Cargo.toml | 6 +- tfhe-zk-pok/src/curve_api/bls12_446.rs | 277 ++++++++++++++++++------- 2 files changed, 205 insertions(+), 78 deletions(-) diff --git a/tfhe-zk-pok/Cargo.toml b/tfhe-zk-pok/Cargo.toml index 9d1a65bf9..e35114136 100644 --- a/tfhe-zk-pok/Cargo.toml +++ b/tfhe-zk-pok/Cargo.toml @@ -13,9 +13,9 @@ description = "tfhe-zk-pok: An implementation of zero-knowledge proofs of encryp [dependencies] ark-bls12-381 = { package = "tfhe-ark-bls12-381", version = "0.4.0" } -ark-ec = { package = "tfhe-ark-ec", version = "0.4.2" } -ark-ff = { package = "tfhe-ark-ff", version = "0.4.3" } -ark-poly = { package = "tfhe-ark-poly", version = "0.4.2" } +ark-ec = { package = "tfhe-ark-ec", version = "0.4.2", features = ["parallel"] } +ark-ff = { package = "tfhe-ark-ff", version = "0.4.3", features = ["parallel"] } +ark-poly = { package = "tfhe-ark-poly", version = "0.4.2", features = ["parallel"] } ark-serialize = { version = "0.4.2" } rand = "0.8.5" rayon = "1.8.0" diff --git a/tfhe-zk-pok/src/curve_api/bls12_446.rs b/tfhe-zk-pok/src/curve_api/bls12_446.rs index 1ea2aa572..f2450389b 100644 --- a/tfhe-zk-pok/src/curve_api/bls12_446.rs +++ b/tfhe-zk-pok/src/curve_api/bls12_446.rs @@ -242,6 +242,96 @@ mod g2 { .unwrap(), } } + + // m is an intermediate variable that's used in both the curve point addition and pairing + // functions. we cache it since it requires a Zp division + // https://hackmd.io/@tazAymRSQCGXTUKkbh1BAg/Sk27liTW9#Math-Formula-for-Point-Addition + pub(crate) fn compute_m(self, other: G2Affine) -> Option { + let zero = crate::curve_446::Fq2::ZERO; + + // in the context of elliptic curves, the point at infinity is the zero element of the + // group + if self.inner.infinity || other.inner.infinity { + return None; + } + + if self == other { + let x = self.inner.x; + let y = self.inner.y; + if y == zero { + None + } else { + let xx = x.square(); + Some((xx.double() + xx) / y.double()) + } + } else { + let x1 = self.inner.x; + let y1 = self.inner.y; + let x2 = other.inner.x; + let y2 = other.inner.y; + + let x_delta = x2 - x1; + let y_delta = y2 - y1; + + if x_delta == zero { + None + } else { + Some(y_delta / x_delta) + } + } + } + + pub(crate) fn double(self, m: Option) -> Self { + // in the context of elliptic curves, the point at infinity is the zero element of the + // group + if self.inner.infinity { + return self; + } + + let mut result = self; + + let x = self.inner.x; + let y = self.inner.y; + + if let Some(m) = m { + let x3 = m.square() - x.double(); + let y3 = m * (x - x3) - y; + + (result.inner.x, result.inner.y) = (x3, y3); + } else { + result.inner.infinity = true; + } + + result + } + + pub(crate) fn add_unequal(self, other: G2Affine, m: Option) -> Self { + // in the context of elliptic curves, the point at infinity is the zero element of the + // group + if self.inner.infinity { + return other; + } + if other.inner.infinity { + return self; + } + + let mut result = self; + + let x1 = self.inner.x; + let y1 = self.inner.y; + let x2 = other.inner.x; + + if let Some(m) = m { + let x3 = m.square() - x1 - x2; + let y3 = m * (x1 - x3) - y1; + + (result.inner.x, result.inner.y) = (x3, y3); + } else { + result.inner.infinity = true; + } + + result + } } #[derive( @@ -373,9 +463,9 @@ mod g2 { } pub fn double(self) -> Self { - Self { - inner: self.inner.double(), - } + let mut this = self; + this.inner.double_in_place(); + this } } @@ -431,51 +521,79 @@ mod g2 { } mod gt { + use crate::curve_446::{Fq, Fq12, Fq2}; + use super::*; - use ark_ec::bls12::Bls12Config; use ark_ec::pairing::{MillerLoopOutput, Pairing}; - use ark_ff::{CubicExtField, Fp12, Fp2, QuadExtField}; + use ark_ff::{CubicExtField, QuadExtField}; type Bls = crate::curve_446::Bls12_446; - type Config = crate::curve_446::Config; - const ONE: Fp2<::Fp2Config> = QuadExtField { - c0: MontFp!("1"), - c1: MontFp!("0"), - }; - const ZERO: Fp2<::Fp2Config> = QuadExtField { + const ZERO: Fq2 = QuadExtField { c0: MontFp!("0"), c1: MontFp!("0"), }; - const U1: Fp12<::Fp12Config> = QuadExtField { - c0: CubicExtField { - c0: ZERO, - c1: ZERO, - c2: ZERO, - }, - c1: CubicExtField { - c0: ONE, - c1: ZERO, - c2: ZERO, - }, + // computed by copying the result from + // let two: Fq = MontFp!("2"); println!("{}", two.inverse().unwrap()), which we can't compute in + // a const context; + const TWO_INV: Fq = { + MontFp!("86412351771428577990035638289747981121746346761394949218917418178192828331138736448451251370148591845087981000773214233672031082665302") }; - const U3: Fp12<::Fp12Config> = QuadExtField { - c0: CubicExtField { - c0: ZERO, - c1: ZERO, - c2: ZERO, - }, - c1: CubicExtField { - c0: ZERO, - c1: ONE, - c2: ZERO, - }, + const TWO_INV_MINUS_1: Fq = { + MontFp!("86412351771428577990035638289747981121746346761394949218917418178192828331138736448451251370148591845087981000773214233672031082665301") }; - const fn fp2_to_fp12( - x: Fp2<::Fp2Config>, - ) -> Fp12<::Fp12Config> { + // the only non zero value in inv(U1) and inv(U3), which come from Olivier's equations. + const C: Fq2 = QuadExtField { + c0: TWO_INV, + c1: TWO_INV_MINUS_1, + }; + + fn fp2_mul_c(x: Fq2) -> Fq2 { + let x0_c0 = x.c0 * C.c0; + let x1_c0 = x.c1 * C.c0; + + let x0_c1 = x0_c0 - x.c0; + let x1_c1 = x1_c0 - x.c1; + + QuadExtField { + c0: x0_c0 - x1_c1, + c1: x0_c1 + x1_c0, + } + } + + fn fp2_mul_u1_inv(x: Fq2) -> Fq12 { + QuadExtField { + c0: CubicExtField { + c0: ZERO, + c1: ZERO, + c2: ZERO, + }, + c1: CubicExtField { + c0: ZERO, + c1: ZERO, + c2: fp2_mul_c(x), + }, + } + } + + fn fp2_mul_u3_inv(x: Fq2) -> Fq12 { + QuadExtField { + c0: CubicExtField { + c0: ZERO, + c1: ZERO, + c2: ZERO, + }, + c1: CubicExtField { + c0: ZERO, + c1: fp2_mul_c(x), + c2: ZERO, + }, + } + } + + const fn fp2_to_fp12(x: Fq2) -> Fq12 { QuadExtField { c0: CubicExtField { c0: x, @@ -490,52 +608,59 @@ mod gt { } } - const fn fp_to_fp12( - x: ::Fp, - ) -> Fp12<::Fp12Config> { - fp2_to_fp12(QuadExtField { + const fn fp_to_fp2(x: Fq) -> Fq2 { + QuadExtField { c0: x, c1: MontFp!("0"), - }) + } } - fn ate_tangent_ev(qt: G2, evpt: G1) -> Fp12<::Fp12Config> { - let qt = qt.inner.into_affine(); - let evpt = evpt.inner.into_affine(); + const fn fp_to_fp12(x: Fq) -> Fq12 { + fp2_to_fp12(fp_to_fp2(x)) + } + + fn ate_tangent_ev(qt: G2Affine, evpt: G1Affine, m: Fq2) -> Fq12 { + let qt = qt.inner; + let evpt = evpt.inner; let (xt, yt) = (qt.x, qt.y); let (xe, ye) = (evpt.x, evpt.y); - let xt = fp2_to_fp12(xt); - let yt = fp2_to_fp12(yt); - let xe = fp_to_fp12(xe); - let ye = fp_to_fp12(ye); + let l = m; + let mut l_xe = l; + l_xe.c0 *= xe; + l_xe.c1 *= xe; - let three = fp_to_fp12(MontFp!("3")); - let two = fp_to_fp12(MontFp!("2")); + let mut r0 = fp_to_fp12(ye); + let r1 = fp2_mul_u1_inv(l_xe); + let r2 = fp2_mul_u3_inv(l * xt - yt); - let l = three * xt.square() / (two * yt); - ye - (l * xe) / U1 + (l * xt - yt) / U3 + r0.c1.c1 = r2.c1.c1; + r0.c1.c2 = -r1.c1.c2; + + r0 } - fn ate_line_ev(q1: G2, q2: G2, evpt: G1) -> Fp12<::Fp12Config> { - let q1 = q1.inner.into_affine(); - let q2 = q2.inner.into_affine(); - let evpt = evpt.inner.into_affine(); + fn ate_line_ev(q1: G2Affine, evpt: G1Affine, m: Fq2) -> Fq12 { + let q1 = q1.inner; + let evpt = evpt.inner; let (x1, y1) = (q1.x, q1.y); - let (x2, y2) = (q2.x, q2.y); let (xe, ye) = (evpt.x, evpt.y); - let x1 = fp2_to_fp12(x1); - let y1 = fp2_to_fp12(y1); - let x2 = fp2_to_fp12(x2); - let y2 = fp2_to_fp12(y2); - let xe = fp_to_fp12(xe); - let ye = fp_to_fp12(ye); + let l = m; + let mut l_xe = l; + l_xe.c0 *= xe; + l_xe.c1 *= xe; - let l = (y2 - y1) / (x2 - x1); - ye - (l * xe) / U1 + (l * x1 - y1) / U3 + let mut r0 = fp_to_fp12(ye); + let r1 = fp2_mul_u1_inv(l * fp_to_fp2(xe)); + let r2 = fp2_mul_u3_inv(l * x1 - y1); + + r0.c1.c1 = r2.c1.c1; + r0.c1.c2 = -r1.c1.c2; + + r0 } #[allow(clippy::needless_range_loop)] @@ -544,22 +669,24 @@ mod gt { let t_bits = b"110000000001000001000000100000000000000000000000000000000100000000000000001"; let mut fk = fp_to_fp12(MontFp!("1")); + let p = p.normalize(); + let q = q.normalize(); + let mut qk = q; for k in 1..t_log2 { - let lkk = ate_tangent_ev(qk, p); - qk = qk + qk; + let m = qk.compute_m(qk).unwrap(); + let lkk = ate_tangent_ev(qk, p, m); + qk = qk.double(Some(m)); fk = fk.square() * lkk; if t_bits[k] == b'1' { - assert_ne!(q, qk); - let lkp1 = if q != -qk { - ate_line_ev(q, qk, p) - } else { - fp_to_fp12(MontFp!("1")) - }; - qk += q; - fk *= lkp1; + let m = q.compute_m(qk); + let new_qk = q.add_unequal(qk, m); + if !new_qk.inner.infinity { + fk *= ate_line_ev(q, p, m.unwrap()); + } + qk = new_qk; } } let mlo = MillerLoopOutput(fk);