feat(zk): implement faster pke proof

- original work by Sarah El kazdadi

co-authored-by: sarah el kazdadi <sarah.elkazdadi@zama.ai>
This commit is contained in:
Arthur Meyre
2024-06-25 15:19:38 +02:00
parent 32b45ac4bc
commit ce9da12e65
9 changed files with 2552 additions and 216 deletions

View File

@@ -22,6 +22,7 @@ rayon = "1.8.0"
sha3 = "0.10.8"
serde = { version = "~1.0", features = ["derive"] }
zeroize = "1.7.0"
num-bigint = "0.4.5"
[dev-dependencies]
serde_json = "~1.0"

View File

@@ -210,9 +210,14 @@ impl CurveGroupOps<bls12_381::Zp> for bls12_381::G1 {
}
fn mul_scalar(self, scalar: bls12_381::Zp) -> Self {
self.mul_scalar(scalar)
if scalar.inner == MontFp!("2") {
self.double()
} else {
self.mul_scalar(scalar)
}
}
#[track_caller]
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_381::Zp]) -> Self {
Self::Affine::multi_mul_scalar(bases, scalars)
}
@@ -245,9 +250,14 @@ impl CurveGroupOps<bls12_381::Zp> for bls12_381::G2 {
}
fn mul_scalar(self, scalar: bls12_381::Zp) -> Self {
self.mul_scalar(scalar)
if scalar.inner == MontFp!("2") {
self.double()
} else {
self.mul_scalar(scalar)
}
}
#[track_caller]
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_381::Zp]) -> Self {
Self::Affine::multi_mul_scalar(bases, scalars)
}
@@ -273,6 +283,9 @@ impl PairingGroupOps<bls12_381::Zp, bls12_381::G1, bls12_381::G2> for bls12_381:
}
fn pairing(x: bls12_381::G1, y: bls12_381::G2) -> Self {
if x == bls12_381::G1::ZERO || y == bls12_381::G2::ZERO {
return Self::pairing(bls12_381::G1::ZERO, bls12_381::G2::GENERATOR);
}
Self::pairing(x, y)
}
}
@@ -329,12 +342,21 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G1 {
}
fn mul_scalar(self, scalar: bls12_446::Zp) -> Self {
self.mul_scalar(scalar)
if scalar.inner == MontFp!("2") {
self.double()
} else {
self.mul_scalar(scalar)
}
}
#[track_caller]
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_446::Zp]) -> Self {
msm::msm_wnaf_g1_446(bases, scalars)
// Self::Affine::multi_mul_scalar(bases, scalars)
// overhead seems to not be worth it outside of wasm
if cfg!(target_family = "wasm") {
msm::msm_wnaf_g1_446(bases, scalars)
} else {
Self::Affine::multi_mul_scalar(bases, scalars)
}
}
fn to_bytes(self) -> impl AsRef<[u8]> {
@@ -365,9 +387,14 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G2 {
}
fn mul_scalar(self, scalar: bls12_446::Zp) -> Self {
self.mul_scalar(scalar)
if scalar.inner == MontFp!("2") {
self.double()
} else {
self.mul_scalar(scalar)
}
}
#[track_caller]
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_446::Zp]) -> Self {
Self::Affine::multi_mul_scalar(bases, scalars)
}
@@ -393,13 +420,16 @@ impl PairingGroupOps<bls12_446::Zp, bls12_446::G1, bls12_446::G2> for bls12_446:
}
fn pairing(x: bls12_446::G1, y: bls12_446::G2) -> Self {
if x == bls12_446::G1::ZERO || y == bls12_446::G2::ZERO {
return Self::pairing(bls12_446::G1::ZERO, bls12_446::G2::GENERATOR);
}
Self::pairing(x, y)
}
}
#[derive(Copy, Clone, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)]
pub struct Bls12_381;
#[derive(Copy, Clone, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)]
pub struct Bls12_446;
impl Curve for Bls12_381 {

View File

@@ -55,6 +55,7 @@ mod g1 {
}
impl G1Affine {
#[track_caller]
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G1 {
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
G1 {
@@ -124,6 +125,7 @@ mod g1 {
}
}
#[track_caller]
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self {
use rayon::prelude::*;
let bases = bases
@@ -230,6 +232,7 @@ mod g2 {
}
impl G2Affine {
#[track_caller]
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G2 {
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
G2 {
@@ -247,10 +250,10 @@ mod g2 {
// functions. we cache it since it requires a Zp division
// https://hackmd.io/@tazAymRSQCGXTUKkbh1BAg/Sk27liTW9#Math-Formula-for-Point-Addition
pub(crate) fn compute_m(self, other: G2Affine) -> Option<crate::curve_446::Fq2> {
let zero = crate::curve_446::Fq2::ZERO;
// in the context of elliptic curves, the point at infinity is the zero element of the
// group
let zero = crate::curve_446::Fq2::ZERO;
if self.inner.infinity || other.inner.infinity {
return None;
}

View File

@@ -1,6 +1,6 @@
use ark_ec::short_weierstrass::Affine;
use ark_ec::AffineRepr;
use ark_ff::{AdditiveGroup, BigInt, BigInteger, Field, Fp, PrimeField};
use ark_ff::{AdditiveGroup, BigInteger, Field, Fp, PrimeField};
use rayon::prelude::*;
fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator<Item = i64> + '_ {
@@ -46,6 +46,7 @@ fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator<
}
// Compute msm using windowed non-adjacent form
#[track_caller]
pub fn msm_wnaf_g1_446(
bases: &[super::bls12_446::G1Affine],
scalars: &[super::bls12_446::Zp],
@@ -236,207 +237,3 @@ pub fn msm_wnaf_g1_446(
total
})
}
// Compute msm using windowed non-adjacent form
pub fn msm_wnaf_g1_446_extended(
bases: &[super::bls12_446::G1Affine],
scalars: &[super::bls12_446::Zp],
) -> super::bls12_446::G1 {
use super::bls12_446::*;
type BaseField = Fp<ark_ff::MontBackend<crate::curve_446::FqConfig, 7>, 7>;
// let num_bits = 75usize;
// let mask = BigInt([!0, (1 << 11) - 1, 0, 0, 0]);
// let scalars = &*scalars
// .par_iter()
// .map(|x| x.inner.into_bigint())
// .flat_map_iter(|x| (0..4).map(move |i| (x >> (75 * i)) & mask))
// .collect::<Vec<_>>();
let num_bits = 150usize;
let mask = BigInt([!0, !0, (1 << 22) - 1, 0, 0]);
let scalars = &*scalars
.par_iter()
.map(|x| x.inner.into_bigint())
.flat_map_iter(|x| (0..2).map(move |i| (x >> (150 * i)) & mask))
.collect::<Vec<_>>();
assert_eq!(bases.len(), scalars.len());
let size = bases.len();
let c = if size < 32 {
3
} else {
// natural log approx
(size.ilog2() as usize * 69 / 100) + 2
};
let c = c - 3;
let digits_count = (num_bits + c - 1) / c;
let scalar_digits = scalars
.into_par_iter()
.flat_map_iter(|s| make_digits(s, c, num_bits))
.collect::<Vec<_>>();
let zero = G1Affine {
inner: Affine::zero(),
};
let window_sums: Vec<_> = (0..digits_count)
.into_par_iter()
.map(|i| {
let n = 1 << c;
let mut indices = vec![vec![]; n];
let mut d = vec![BaseField::ZERO; n + 1];
let mut e = vec![BaseField::ZERO; n + 1];
for (idx, digits) in scalar_digits.chunks(digits_count).enumerate() {
use core::cmp::Ordering;
// digits is the digits thing of the first scalar?
let scalar = digits[i];
match 0.cmp(&scalar) {
Ordering::Less => indices[(scalar - 1) as usize].push(idx),
Ordering::Greater => indices[(-scalar - 1) as usize].push(!idx),
Ordering::Equal => (),
}
}
let mut buckets = vec![zero; 1 << c];
loop {
d[0] = BaseField::ONE;
for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices).enumerate() {
if let Some(idx) = idx.last().copied() {
let value = if idx >> (usize::BITS - 1) == 1 {
let mut val = bases[!idx];
val.inner.y = -val.inner.y;
val
} else {
bases[idx]
};
if !bucket.inner.infinity {
let a = value.inner.x - bucket.inner.x;
if a != BaseField::ZERO {
d[k + 1] = d[k] * a;
} else if value.inner.y == bucket.inner.y {
d[k + 1] = d[k] * value.inner.y.double();
} else {
d[k + 1] = d[k];
}
continue;
}
}
d[k + 1] = d[k];
}
e[n] = d[n].inverse().unwrap();
for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices)
.enumerate()
.rev()
{
if let Some(idx) = idx.last().copied() {
let value = if idx >> (usize::BITS - 1) == 1 {
let mut val = bases[!idx];
val.inner.y = -val.inner.y;
val
} else {
bases[idx]
};
if !bucket.inner.infinity {
let a = value.inner.x - bucket.inner.x;
if a != BaseField::ZERO {
e[k] = e[k + 1] * a;
} else if value.inner.y == bucket.inner.y {
e[k] = e[k + 1] * value.inner.y.double();
} else {
e[k] = e[k + 1];
}
continue;
}
}
e[k] = e[k + 1];
}
let d = &d[..n];
let e = &e[1..];
let mut empty = true;
for ((&d, &e), (bucket, idx)) in core::iter::zip(
core::iter::zip(d, e),
core::iter::zip(&mut buckets, &mut indices),
) {
empty &= idx.len() <= 1;
if let Some(idx) = idx.pop() {
let value = if idx >> (usize::BITS - 1) == 1 {
let mut val = bases[!idx];
val.inner.y = -val.inner.y;
val
} else {
bases[idx]
};
if !bucket.inner.infinity {
let x1 = bucket.inner.x;
let x2 = value.inner.x;
let y1 = bucket.inner.y;
let y2 = value.inner.y;
let eq_x = x1 == x2;
if eq_x && y1 != y2 {
bucket.inner.infinity = true;
} else {
let r = d * e;
let m = if eq_x {
let x1 = x1.square();
x1 + x1.double()
} else {
y2 - y1
};
let m = m * r;
let x3 = m.square() - x1 - x2;
let y3 = m * (x1 - x3) - y1;
bucket.inner.x = x3;
bucket.inner.y = y3;
}
} else {
*bucket = value;
}
}
}
if empty {
break;
}
}
let mut running_sum = G1::ZERO;
let mut res = G1::ZERO;
buckets.into_iter().rev().for_each(|b| {
running_sum.inner += b.inner;
res += running_sum;
});
res
})
.collect();
// We store the sum for the lowest window.
let lowest = *window_sums.first().unwrap();
// We're traversing windows from high to low.
lowest
+ window_sums[1..]
.iter()
.rev()
.fold(G1::ZERO, |mut total, &sum_i| {
total += sum_i;
for _ in 0..c {
total = total.double();
}
total
})
}

View File

@@ -0,0 +1,308 @@
use ark_ff::biginteger::arithmetic::widening_mul;
use rand::prelude::*;
pub fn sqr<T: Copy + core::ops::Mul>(x: T) -> T::Output {
x * x
}
// copied from the standard library
// since isqrt is unstable at the moment
pub fn isqrt(this: u128) -> u128 {
if this < 2 {
return this;
}
// The algorithm is based on the one presented in
// <https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)>
// which cites as source the following C code:
// <https://web.archive.org/web/20120306040058/http://medialab.freaknet.org/martin/src/sqrt/sqrt.c>.
let mut op = this;
let mut res = 0;
let mut one = 1 << (this.ilog2() & !1);
while one != 0 {
if op >= res + one {
op -= res + one;
res = (res >> 1) + one;
} else {
res >>= 1;
}
one >>= 2;
}
res
}
fn half_gcd(p: u128, s: u128) -> u128 {
let sq_p = isqrt(p as _);
let mut a = p;
let mut b = s;
while b > sq_p {
let r = a % b;
a = b;
b = r;
}
b
}
fn modular_inv_2_64(p: u64) -> u64 {
assert_eq!(p % 2, 1);
let mut old_r = p as u128;
let mut r = 1u128 << 64;
let mut old_s = 1u64;
let mut s = 0u64;
while r != 0 {
let q = old_r / r;
(old_r, r) = (r, old_r - q * r);
let q = q as u64;
(old_s, s) = (s, old_s.wrapping_sub(q.wrapping_mul(s)));
}
assert_eq!(u64::wrapping_mul(old_s, p), 1);
old_s
}
#[derive(Copy, Clone, Debug)]
struct Montgomery {
p: u128,
r2: u128,
p_prime: u64,
}
impl Montgomery {
fn new(p: u128) -> Self {
assert_ne!(p, 0);
assert_eq!(p % 2, 1);
// r = 2^128
// we want to compute r^2 mod p
let r = p.wrapping_neg() % p;
let r = num_bigint::BigUint::from(r);
let r2 = &r * &r;
let r2 = r2 % p;
let r2_digits = &*r2.to_u64_digits();
let r2 = match *r2_digits {
[] => 0u128,
[a] => a as u128,
[a, b] => a as u128 | ((b as u128) << 64),
_ => unreachable!("value modulo 128 bit integer should have at most two u64 digits"),
};
let p_prime = modular_inv_2_64(p as u64).wrapping_neg();
Self { p, r2, p_prime }
}
fn redc(self, lo: u128, hi: u128) -> u128 {
let p0 = self.p as u64;
let p1 = (self.p >> 64) as u64;
let t0 = lo as u64;
let mut t1 = (lo >> 64) as u64;
let mut t2 = hi as u64;
let mut t3 = (hi >> 64) as u64;
let mut t4 = 0u64;
{
let m = u64::wrapping_mul(t0, self.p_prime);
let mut c = 0u64;
let x = c as u128 + t0 as u128 + widening_mul(m, p0);
// t0 = x as u64;
c = (x >> 64) as u64;
let x = c as u128 + t1 as u128 + widening_mul(m, p1);
t1 = x as u64;
c = (x >> 64) as u64;
let x = c as u128 + t2 as u128;
t2 = x as u64;
c = (x >> 64) as u64;
let x = c as u128 + t3 as u128;
t3 = x as u64;
c = (x >> 64) as u64;
t4 += c;
}
{
let m = u64::wrapping_mul(t1, self.p_prime);
let mut c = 0u64;
let x = c as u128 + t1 as u128 + widening_mul(m, p0);
// t1 = x as u64;
c = (x >> 64) as u64;
let x = c as u128 + t2 as u128 + widening_mul(m, p1);
t2 = x as u64;
c = (x >> 64) as u64;
let x = c as u128 + t3 as u128;
t3 = x as u64;
c = (x >> 64) as u64;
t4 += c;
}
let mut s0 = t2;
let mut s1 = t3;
let s2 = t4;
if !(s2 == 0 && (s1, s0) < (p1, p0)) {
let borrow;
(s0, borrow) = u64::overflowing_sub(s0, p0);
s1 = s1.wrapping_sub(p1).wrapping_sub(borrow as u64);
}
s0 as u128 | ((s1 as u128) << 64)
}
fn mont_from_natural(self, x: u128) -> u128 {
self.mul(x, self.r2)
}
fn natural_from_mont(self, x: u128) -> u128 {
self.redc(x, 0)
}
fn mul(self, x: u128, y: u128) -> u128 {
let x0 = x as u64;
let x1 = (x >> 64) as u64;
let y0 = y as u64;
let y1 = (y >> 64) as u64;
let lolo = widening_mul(x0, y0);
let lohi = widening_mul(x0, y1);
let hilo = widening_mul(x1, y0);
let hihi = widening_mul(x1, y1);
let lo = lolo;
let (lo, o0) = u128::overflowing_add(lo, lohi << 64);
let (lo, o1) = u128::overflowing_add(lo, hilo << 64);
let hi = hihi + (lohi >> 64) + (hilo >> 64) + (o0 as u128 + o1 as u128);
self.redc(lo, hi)
}
fn exp(self, x: u128, n: u128) -> u128 {
if n == 0 {
return 1;
}
let mut y = self.mont_from_natural(1);
let mut x = x;
let mut n = n;
while n > 1 {
if n % 2 == 1 {
y = self.mul(x, y);
}
x = self.mul(x, x);
n /= 2;
}
self.mul(x, y)
}
}
pub fn four_squares(v: u128) -> [u64; 4] {
let rng = &mut StdRng::seed_from_u64(0);
let f = v % 4;
if f == 2 {
let b = isqrt(v as _) as u64;
'main_loop: loop {
let x = 2 + rng.gen::<u64>() % (b - 2);
let y = 2 + rng.gen::<u64>() % (b - 2);
let (sum, o) = u128::overflowing_add(sqr(x as u128), sqr(y as u128));
if o || sum > v {
continue 'main_loop;
}
let p = v - sum;
if p == 0 || p == 1 {
return [0, p as u64, x, y];
}
if p % 4 != 1 {
continue 'main_loop;
}
let mut d = p - 1;
let mut s = 0u32;
while d % 2 == 0 {
d /= 2;
s += 1;
}
let d = d;
let s = s;
let mont = Montgomery::new(p);
let a = 2 + (rng.gen::<u128>() % (p - 3));
let mut sqrt = 0;
{
let a = mont.mont_from_natural(a);
let one = mont.mont_from_natural(1);
let neg_one = p - one;
let mut x = mont.exp(a, d);
let mut y = 0;
for _ in 0..s {
y = mont.mul(x, x);
if y == one && x != one && x != neg_one {
continue 'main_loop;
}
if y == neg_one {
sqrt = x;
}
x = y;
}
if y != one {
continue 'main_loop;
}
}
if sqrt == 0 {
continue 'main_loop;
}
let i = mont.natural_from_mont(sqrt);
let i = if i <= p / 2 { p - i } else { i };
let z = half_gcd(p, i) as u64;
let w = isqrt(p - sqr(z as u128)) as u64;
if p != sqr(z as u128) + sqr(w as u128) {
continue 'main_loop;
}
return [x, y, z, w];
}
} else if f == 0 {
four_squares(v / 4).map(|x| x + x)
} else {
let mut r = four_squares(2 * v);
r.sort_by_key(|&x| {
if x % 2 == 0 {
-1 - ((x / 2) as i64)
} else {
(x / 2) as i64
}
});
[
(r[0] + r[1]) / 2,
(r[0] - r[1]) / 2,
(r[3] + r[2]) / 2,
(r[3] - r[2]) / 2,
]
}
}

View File

@@ -3,3 +3,5 @@ pub use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Vali
pub mod curve_446;
pub mod curve_api;
pub mod proofs;
mod four_squares;

View File

@@ -141,5 +141,6 @@ pub const HASH_METADATA_LEN_BYTES: usize = 256;
pub mod binary;
pub mod index;
pub mod pke;
pub mod pke_v2;
pub mod range;
pub mod rlwe;

View File

@@ -101,7 +101,7 @@ impl<G: Curve> PublicCommit<G> {
b,
c1,
c2,
__marker: Default::default(),
__marker: PhantomData,
}
}
}

File diff suppressed because it is too large Load Diff