mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 07:08:03 -05:00
feat(zk): implement faster pke proof
- original work by Sarah El kazdadi co-authored-by: sarah el kazdadi <sarah.elkazdadi@zama.ai>
This commit is contained in:
@@ -22,6 +22,7 @@ rayon = "1.8.0"
|
||||
sha3 = "0.10.8"
|
||||
serde = { version = "~1.0", features = ["derive"] }
|
||||
zeroize = "1.7.0"
|
||||
num-bigint = "0.4.5"
|
||||
|
||||
[dev-dependencies]
|
||||
serde_json = "~1.0"
|
||||
|
||||
@@ -210,9 +210,14 @@ impl CurveGroupOps<bls12_381::Zp> for bls12_381::G1 {
|
||||
}
|
||||
|
||||
fn mul_scalar(self, scalar: bls12_381::Zp) -> Self {
|
||||
self.mul_scalar(scalar)
|
||||
if scalar.inner == MontFp!("2") {
|
||||
self.double()
|
||||
} else {
|
||||
self.mul_scalar(scalar)
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_381::Zp]) -> Self {
|
||||
Self::Affine::multi_mul_scalar(bases, scalars)
|
||||
}
|
||||
@@ -245,9 +250,14 @@ impl CurveGroupOps<bls12_381::Zp> for bls12_381::G2 {
|
||||
}
|
||||
|
||||
fn mul_scalar(self, scalar: bls12_381::Zp) -> Self {
|
||||
self.mul_scalar(scalar)
|
||||
if scalar.inner == MontFp!("2") {
|
||||
self.double()
|
||||
} else {
|
||||
self.mul_scalar(scalar)
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_381::Zp]) -> Self {
|
||||
Self::Affine::multi_mul_scalar(bases, scalars)
|
||||
}
|
||||
@@ -273,6 +283,9 @@ impl PairingGroupOps<bls12_381::Zp, bls12_381::G1, bls12_381::G2> for bls12_381:
|
||||
}
|
||||
|
||||
fn pairing(x: bls12_381::G1, y: bls12_381::G2) -> Self {
|
||||
if x == bls12_381::G1::ZERO || y == bls12_381::G2::ZERO {
|
||||
return Self::pairing(bls12_381::G1::ZERO, bls12_381::G2::GENERATOR);
|
||||
}
|
||||
Self::pairing(x, y)
|
||||
}
|
||||
}
|
||||
@@ -329,12 +342,21 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G1 {
|
||||
}
|
||||
|
||||
fn mul_scalar(self, scalar: bls12_446::Zp) -> Self {
|
||||
self.mul_scalar(scalar)
|
||||
if scalar.inner == MontFp!("2") {
|
||||
self.double()
|
||||
} else {
|
||||
self.mul_scalar(scalar)
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_446::Zp]) -> Self {
|
||||
msm::msm_wnaf_g1_446(bases, scalars)
|
||||
// Self::Affine::multi_mul_scalar(bases, scalars)
|
||||
// overhead seems to not be worth it outside of wasm
|
||||
if cfg!(target_family = "wasm") {
|
||||
msm::msm_wnaf_g1_446(bases, scalars)
|
||||
} else {
|
||||
Self::Affine::multi_mul_scalar(bases, scalars)
|
||||
}
|
||||
}
|
||||
|
||||
fn to_bytes(self) -> impl AsRef<[u8]> {
|
||||
@@ -365,9 +387,14 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G2 {
|
||||
}
|
||||
|
||||
fn mul_scalar(self, scalar: bls12_446::Zp) -> Self {
|
||||
self.mul_scalar(scalar)
|
||||
if scalar.inner == MontFp!("2") {
|
||||
self.double()
|
||||
} else {
|
||||
self.mul_scalar(scalar)
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_446::Zp]) -> Self {
|
||||
Self::Affine::multi_mul_scalar(bases, scalars)
|
||||
}
|
||||
@@ -393,13 +420,16 @@ impl PairingGroupOps<bls12_446::Zp, bls12_446::G1, bls12_446::G2> for bls12_446:
|
||||
}
|
||||
|
||||
fn pairing(x: bls12_446::G1, y: bls12_446::G2) -> Self {
|
||||
if x == bls12_446::G1::ZERO || y == bls12_446::G2::ZERO {
|
||||
return Self::pairing(bls12_446::G1::ZERO, bls12_446::G2::GENERATOR);
|
||||
}
|
||||
Self::pairing(x, y)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct Bls12_381;
|
||||
#[derive(Copy, Clone, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct Bls12_446;
|
||||
|
||||
impl Curve for Bls12_381 {
|
||||
|
||||
@@ -55,6 +55,7 @@ mod g1 {
|
||||
}
|
||||
|
||||
impl G1Affine {
|
||||
#[track_caller]
|
||||
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G1 {
|
||||
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
|
||||
G1 {
|
||||
@@ -124,6 +125,7 @@ mod g1 {
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self {
|
||||
use rayon::prelude::*;
|
||||
let bases = bases
|
||||
@@ -230,6 +232,7 @@ mod g2 {
|
||||
}
|
||||
|
||||
impl G2Affine {
|
||||
#[track_caller]
|
||||
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G2 {
|
||||
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
|
||||
G2 {
|
||||
@@ -247,10 +250,10 @@ mod g2 {
|
||||
// functions. we cache it since it requires a Zp division
|
||||
// https://hackmd.io/@tazAymRSQCGXTUKkbh1BAg/Sk27liTW9#Math-Formula-for-Point-Addition
|
||||
pub(crate) fn compute_m(self, other: G2Affine) -> Option<crate::curve_446::Fq2> {
|
||||
let zero = crate::curve_446::Fq2::ZERO;
|
||||
|
||||
// in the context of elliptic curves, the point at infinity is the zero element of the
|
||||
// group
|
||||
let zero = crate::curve_446::Fq2::ZERO;
|
||||
|
||||
if self.inner.infinity || other.inner.infinity {
|
||||
return None;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use ark_ec::short_weierstrass::Affine;
|
||||
use ark_ec::AffineRepr;
|
||||
use ark_ff::{AdditiveGroup, BigInt, BigInteger, Field, Fp, PrimeField};
|
||||
use ark_ff::{AdditiveGroup, BigInteger, Field, Fp, PrimeField};
|
||||
use rayon::prelude::*;
|
||||
|
||||
fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator<Item = i64> + '_ {
|
||||
@@ -46,6 +46,7 @@ fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator<
|
||||
}
|
||||
|
||||
// Compute msm using windowed non-adjacent form
|
||||
#[track_caller]
|
||||
pub fn msm_wnaf_g1_446(
|
||||
bases: &[super::bls12_446::G1Affine],
|
||||
scalars: &[super::bls12_446::Zp],
|
||||
@@ -236,207 +237,3 @@ pub fn msm_wnaf_g1_446(
|
||||
total
|
||||
})
|
||||
}
|
||||
|
||||
// Compute msm using windowed non-adjacent form
|
||||
pub fn msm_wnaf_g1_446_extended(
|
||||
bases: &[super::bls12_446::G1Affine],
|
||||
scalars: &[super::bls12_446::Zp],
|
||||
) -> super::bls12_446::G1 {
|
||||
use super::bls12_446::*;
|
||||
type BaseField = Fp<ark_ff::MontBackend<crate::curve_446::FqConfig, 7>, 7>;
|
||||
|
||||
// let num_bits = 75usize;
|
||||
// let mask = BigInt([!0, (1 << 11) - 1, 0, 0, 0]);
|
||||
// let scalars = &*scalars
|
||||
// .par_iter()
|
||||
// .map(|x| x.inner.into_bigint())
|
||||
// .flat_map_iter(|x| (0..4).map(move |i| (x >> (75 * i)) & mask))
|
||||
// .collect::<Vec<_>>();
|
||||
|
||||
let num_bits = 150usize;
|
||||
let mask = BigInt([!0, !0, (1 << 22) - 1, 0, 0]);
|
||||
let scalars = &*scalars
|
||||
.par_iter()
|
||||
.map(|x| x.inner.into_bigint())
|
||||
.flat_map_iter(|x| (0..2).map(move |i| (x >> (150 * i)) & mask))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(bases.len(), scalars.len());
|
||||
|
||||
let size = bases.len();
|
||||
|
||||
let c = if size < 32 {
|
||||
3
|
||||
} else {
|
||||
// natural log approx
|
||||
(size.ilog2() as usize * 69 / 100) + 2
|
||||
};
|
||||
let c = c - 3;
|
||||
|
||||
let digits_count = (num_bits + c - 1) / c;
|
||||
let scalar_digits = scalars
|
||||
.into_par_iter()
|
||||
.flat_map_iter(|s| make_digits(s, c, num_bits))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let zero = G1Affine {
|
||||
inner: Affine::zero(),
|
||||
};
|
||||
|
||||
let window_sums: Vec<_> = (0..digits_count)
|
||||
.into_par_iter()
|
||||
.map(|i| {
|
||||
let n = 1 << c;
|
||||
let mut indices = vec![vec![]; n];
|
||||
let mut d = vec![BaseField::ZERO; n + 1];
|
||||
let mut e = vec![BaseField::ZERO; n + 1];
|
||||
|
||||
for (idx, digits) in scalar_digits.chunks(digits_count).enumerate() {
|
||||
use core::cmp::Ordering;
|
||||
// digits is the digits thing of the first scalar?
|
||||
let scalar = digits[i];
|
||||
match 0.cmp(&scalar) {
|
||||
Ordering::Less => indices[(scalar - 1) as usize].push(idx),
|
||||
Ordering::Greater => indices[(-scalar - 1) as usize].push(!idx),
|
||||
Ordering::Equal => (),
|
||||
}
|
||||
}
|
||||
|
||||
let mut buckets = vec![zero; 1 << c];
|
||||
|
||||
loop {
|
||||
d[0] = BaseField::ONE;
|
||||
for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices).enumerate() {
|
||||
if let Some(idx) = idx.last().copied() {
|
||||
let value = if idx >> (usize::BITS - 1) == 1 {
|
||||
let mut val = bases[!idx];
|
||||
val.inner.y = -val.inner.y;
|
||||
val
|
||||
} else {
|
||||
bases[idx]
|
||||
};
|
||||
|
||||
if !bucket.inner.infinity {
|
||||
let a = value.inner.x - bucket.inner.x;
|
||||
if a != BaseField::ZERO {
|
||||
d[k + 1] = d[k] * a;
|
||||
} else if value.inner.y == bucket.inner.y {
|
||||
d[k + 1] = d[k] * value.inner.y.double();
|
||||
} else {
|
||||
d[k + 1] = d[k];
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
d[k + 1] = d[k];
|
||||
}
|
||||
e[n] = d[n].inverse().unwrap();
|
||||
|
||||
for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices)
|
||||
.enumerate()
|
||||
.rev()
|
||||
{
|
||||
if let Some(idx) = idx.last().copied() {
|
||||
let value = if idx >> (usize::BITS - 1) == 1 {
|
||||
let mut val = bases[!idx];
|
||||
val.inner.y = -val.inner.y;
|
||||
val
|
||||
} else {
|
||||
bases[idx]
|
||||
};
|
||||
|
||||
if !bucket.inner.infinity {
|
||||
let a = value.inner.x - bucket.inner.x;
|
||||
if a != BaseField::ZERO {
|
||||
e[k] = e[k + 1] * a;
|
||||
} else if value.inner.y == bucket.inner.y {
|
||||
e[k] = e[k + 1] * value.inner.y.double();
|
||||
} else {
|
||||
e[k] = e[k + 1];
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
e[k] = e[k + 1];
|
||||
}
|
||||
|
||||
let d = &d[..n];
|
||||
let e = &e[1..];
|
||||
|
||||
let mut empty = true;
|
||||
for ((&d, &e), (bucket, idx)) in core::iter::zip(
|
||||
core::iter::zip(d, e),
|
||||
core::iter::zip(&mut buckets, &mut indices),
|
||||
) {
|
||||
empty &= idx.len() <= 1;
|
||||
if let Some(idx) = idx.pop() {
|
||||
let value = if idx >> (usize::BITS - 1) == 1 {
|
||||
let mut val = bases[!idx];
|
||||
val.inner.y = -val.inner.y;
|
||||
val
|
||||
} else {
|
||||
bases[idx]
|
||||
};
|
||||
|
||||
if !bucket.inner.infinity {
|
||||
let x1 = bucket.inner.x;
|
||||
let x2 = value.inner.x;
|
||||
let y1 = bucket.inner.y;
|
||||
let y2 = value.inner.y;
|
||||
|
||||
let eq_x = x1 == x2;
|
||||
|
||||
if eq_x && y1 != y2 {
|
||||
bucket.inner.infinity = true;
|
||||
} else {
|
||||
let r = d * e;
|
||||
let m = if eq_x {
|
||||
let x1 = x1.square();
|
||||
x1 + x1.double()
|
||||
} else {
|
||||
y2 - y1
|
||||
};
|
||||
let m = m * r;
|
||||
|
||||
let x3 = m.square() - x1 - x2;
|
||||
let y3 = m * (x1 - x3) - y1;
|
||||
bucket.inner.x = x3;
|
||||
bucket.inner.y = y3;
|
||||
}
|
||||
} else {
|
||||
*bucket = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if empty {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut running_sum = G1::ZERO;
|
||||
let mut res = G1::ZERO;
|
||||
buckets.into_iter().rev().for_each(|b| {
|
||||
running_sum.inner += b.inner;
|
||||
res += running_sum;
|
||||
});
|
||||
res
|
||||
})
|
||||
.collect();
|
||||
|
||||
// We store the sum for the lowest window.
|
||||
let lowest = *window_sums.first().unwrap();
|
||||
|
||||
// We're traversing windows from high to low.
|
||||
lowest
|
||||
+ window_sums[1..]
|
||||
.iter()
|
||||
.rev()
|
||||
.fold(G1::ZERO, |mut total, &sum_i| {
|
||||
total += sum_i;
|
||||
for _ in 0..c {
|
||||
total = total.double();
|
||||
}
|
||||
total
|
||||
})
|
||||
}
|
||||
|
||||
308
tfhe-zk-pok/src/four_squares.rs
Normal file
308
tfhe-zk-pok/src/four_squares.rs
Normal file
@@ -0,0 +1,308 @@
|
||||
use ark_ff::biginteger::arithmetic::widening_mul;
|
||||
use rand::prelude::*;
|
||||
|
||||
pub fn sqr<T: Copy + core::ops::Mul>(x: T) -> T::Output {
|
||||
x * x
|
||||
}
|
||||
|
||||
// copied from the standard library
|
||||
// since isqrt is unstable at the moment
|
||||
pub fn isqrt(this: u128) -> u128 {
|
||||
if this < 2 {
|
||||
return this;
|
||||
}
|
||||
|
||||
// The algorithm is based on the one presented in
|
||||
// <https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)>
|
||||
// which cites as source the following C code:
|
||||
// <https://web.archive.org/web/20120306040058/http://medialab.freaknet.org/martin/src/sqrt/sqrt.c>.
|
||||
|
||||
let mut op = this;
|
||||
let mut res = 0;
|
||||
let mut one = 1 << (this.ilog2() & !1);
|
||||
|
||||
while one != 0 {
|
||||
if op >= res + one {
|
||||
op -= res + one;
|
||||
res = (res >> 1) + one;
|
||||
} else {
|
||||
res >>= 1;
|
||||
}
|
||||
one >>= 2;
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
fn half_gcd(p: u128, s: u128) -> u128 {
|
||||
let sq_p = isqrt(p as _);
|
||||
let mut a = p;
|
||||
let mut b = s;
|
||||
while b > sq_p {
|
||||
let r = a % b;
|
||||
a = b;
|
||||
b = r;
|
||||
}
|
||||
b
|
||||
}
|
||||
|
||||
fn modular_inv_2_64(p: u64) -> u64 {
|
||||
assert_eq!(p % 2, 1);
|
||||
|
||||
let mut old_r = p as u128;
|
||||
let mut r = 1u128 << 64;
|
||||
|
||||
let mut old_s = 1u64;
|
||||
let mut s = 0u64;
|
||||
|
||||
while r != 0 {
|
||||
let q = old_r / r;
|
||||
(old_r, r) = (r, old_r - q * r);
|
||||
|
||||
let q = q as u64;
|
||||
(old_s, s) = (s, old_s.wrapping_sub(q.wrapping_mul(s)));
|
||||
}
|
||||
|
||||
assert_eq!(u64::wrapping_mul(old_s, p), 1);
|
||||
old_s
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
struct Montgomery {
|
||||
p: u128,
|
||||
r2: u128,
|
||||
p_prime: u64,
|
||||
}
|
||||
|
||||
impl Montgomery {
|
||||
fn new(p: u128) -> Self {
|
||||
assert_ne!(p, 0);
|
||||
assert_eq!(p % 2, 1);
|
||||
|
||||
// r = 2^128
|
||||
// we want to compute r^2 mod p
|
||||
let r = p.wrapping_neg() % p;
|
||||
|
||||
let r = num_bigint::BigUint::from(r);
|
||||
let r2 = &r * &r;
|
||||
let r2 = r2 % p;
|
||||
let r2_digits = &*r2.to_u64_digits();
|
||||
|
||||
let r2 = match *r2_digits {
|
||||
[] => 0u128,
|
||||
[a] => a as u128,
|
||||
[a, b] => a as u128 | ((b as u128) << 64),
|
||||
_ => unreachable!("value modulo 128 bit integer should have at most two u64 digits"),
|
||||
};
|
||||
|
||||
let p_prime = modular_inv_2_64(p as u64).wrapping_neg();
|
||||
|
||||
Self { p, r2, p_prime }
|
||||
}
|
||||
|
||||
fn redc(self, lo: u128, hi: u128) -> u128 {
|
||||
let p0 = self.p as u64;
|
||||
let p1 = (self.p >> 64) as u64;
|
||||
|
||||
let t0 = lo as u64;
|
||||
let mut t1 = (lo >> 64) as u64;
|
||||
let mut t2 = hi as u64;
|
||||
let mut t3 = (hi >> 64) as u64;
|
||||
let mut t4 = 0u64;
|
||||
|
||||
{
|
||||
let m = u64::wrapping_mul(t0, self.p_prime);
|
||||
let mut c = 0u64;
|
||||
|
||||
let x = c as u128 + t0 as u128 + widening_mul(m, p0);
|
||||
// t0 = x as u64;
|
||||
c = (x >> 64) as u64;
|
||||
|
||||
let x = c as u128 + t1 as u128 + widening_mul(m, p1);
|
||||
t1 = x as u64;
|
||||
c = (x >> 64) as u64;
|
||||
|
||||
let x = c as u128 + t2 as u128;
|
||||
t2 = x as u64;
|
||||
c = (x >> 64) as u64;
|
||||
|
||||
let x = c as u128 + t3 as u128;
|
||||
t3 = x as u64;
|
||||
c = (x >> 64) as u64;
|
||||
|
||||
t4 += c;
|
||||
}
|
||||
|
||||
{
|
||||
let m = u64::wrapping_mul(t1, self.p_prime);
|
||||
let mut c = 0u64;
|
||||
|
||||
let x = c as u128 + t1 as u128 + widening_mul(m, p0);
|
||||
// t1 = x as u64;
|
||||
c = (x >> 64) as u64;
|
||||
|
||||
let x = c as u128 + t2 as u128 + widening_mul(m, p1);
|
||||
t2 = x as u64;
|
||||
c = (x >> 64) as u64;
|
||||
|
||||
let x = c as u128 + t3 as u128;
|
||||
t3 = x as u64;
|
||||
c = (x >> 64) as u64;
|
||||
|
||||
t4 += c;
|
||||
}
|
||||
|
||||
let mut s0 = t2;
|
||||
let mut s1 = t3;
|
||||
let s2 = t4;
|
||||
|
||||
if !(s2 == 0 && (s1, s0) < (p1, p0)) {
|
||||
let borrow;
|
||||
(s0, borrow) = u64::overflowing_sub(s0, p0);
|
||||
s1 = s1.wrapping_sub(p1).wrapping_sub(borrow as u64);
|
||||
}
|
||||
|
||||
s0 as u128 | ((s1 as u128) << 64)
|
||||
}
|
||||
|
||||
fn mont_from_natural(self, x: u128) -> u128 {
|
||||
self.mul(x, self.r2)
|
||||
}
|
||||
|
||||
fn natural_from_mont(self, x: u128) -> u128 {
|
||||
self.redc(x, 0)
|
||||
}
|
||||
|
||||
fn mul(self, x: u128, y: u128) -> u128 {
|
||||
let x0 = x as u64;
|
||||
let x1 = (x >> 64) as u64;
|
||||
let y0 = y as u64;
|
||||
let y1 = (y >> 64) as u64;
|
||||
|
||||
let lolo = widening_mul(x0, y0);
|
||||
let lohi = widening_mul(x0, y1);
|
||||
let hilo = widening_mul(x1, y0);
|
||||
let hihi = widening_mul(x1, y1);
|
||||
|
||||
let lo = lolo;
|
||||
let (lo, o0) = u128::overflowing_add(lo, lohi << 64);
|
||||
let (lo, o1) = u128::overflowing_add(lo, hilo << 64);
|
||||
|
||||
let hi = hihi + (lohi >> 64) + (hilo >> 64) + (o0 as u128 + o1 as u128);
|
||||
|
||||
self.redc(lo, hi)
|
||||
}
|
||||
|
||||
fn exp(self, x: u128, n: u128) -> u128 {
|
||||
if n == 0 {
|
||||
return 1;
|
||||
}
|
||||
let mut y = self.mont_from_natural(1);
|
||||
let mut x = x;
|
||||
let mut n = n;
|
||||
while n > 1 {
|
||||
if n % 2 == 1 {
|
||||
y = self.mul(x, y);
|
||||
}
|
||||
x = self.mul(x, x);
|
||||
n /= 2;
|
||||
}
|
||||
self.mul(x, y)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn four_squares(v: u128) -> [u64; 4] {
|
||||
let rng = &mut StdRng::seed_from_u64(0);
|
||||
|
||||
let f = v % 4;
|
||||
if f == 2 {
|
||||
let b = isqrt(v as _) as u64;
|
||||
|
||||
'main_loop: loop {
|
||||
let x = 2 + rng.gen::<u64>() % (b - 2);
|
||||
let y = 2 + rng.gen::<u64>() % (b - 2);
|
||||
|
||||
let (sum, o) = u128::overflowing_add(sqr(x as u128), sqr(y as u128));
|
||||
if o || sum > v {
|
||||
continue 'main_loop;
|
||||
}
|
||||
|
||||
let p = v - sum;
|
||||
|
||||
if p == 0 || p == 1 {
|
||||
return [0, p as u64, x, y];
|
||||
}
|
||||
|
||||
if p % 4 != 1 {
|
||||
continue 'main_loop;
|
||||
}
|
||||
|
||||
let mut d = p - 1;
|
||||
let mut s = 0u32;
|
||||
while d % 2 == 0 {
|
||||
d /= 2;
|
||||
s += 1;
|
||||
}
|
||||
let d = d;
|
||||
let s = s;
|
||||
|
||||
let mont = Montgomery::new(p);
|
||||
let a = 2 + (rng.gen::<u128>() % (p - 3));
|
||||
|
||||
let mut sqrt = 0;
|
||||
{
|
||||
let a = mont.mont_from_natural(a);
|
||||
let one = mont.mont_from_natural(1);
|
||||
let neg_one = p - one;
|
||||
|
||||
let mut x = mont.exp(a, d);
|
||||
let mut y = 0;
|
||||
|
||||
for _ in 0..s {
|
||||
y = mont.mul(x, x);
|
||||
if y == one && x != one && x != neg_one {
|
||||
continue 'main_loop;
|
||||
}
|
||||
if y == neg_one {
|
||||
sqrt = x;
|
||||
}
|
||||
x = y;
|
||||
}
|
||||
if y != one {
|
||||
continue 'main_loop;
|
||||
}
|
||||
}
|
||||
if sqrt == 0 {
|
||||
continue 'main_loop;
|
||||
}
|
||||
|
||||
let i = mont.natural_from_mont(sqrt);
|
||||
let i = if i <= p / 2 { p - i } else { i };
|
||||
let z = half_gcd(p, i) as u64;
|
||||
let w = isqrt(p - sqr(z as u128)) as u64;
|
||||
|
||||
if p != sqr(z as u128) + sqr(w as u128) {
|
||||
continue 'main_loop;
|
||||
}
|
||||
|
||||
return [x, y, z, w];
|
||||
}
|
||||
} else if f == 0 {
|
||||
four_squares(v / 4).map(|x| x + x)
|
||||
} else {
|
||||
let mut r = four_squares(2 * v);
|
||||
r.sort_by_key(|&x| {
|
||||
if x % 2 == 0 {
|
||||
-1 - ((x / 2) as i64)
|
||||
} else {
|
||||
(x / 2) as i64
|
||||
}
|
||||
});
|
||||
[
|
||||
(r[0] + r[1]) / 2,
|
||||
(r[0] - r[1]) / 2,
|
||||
(r[3] + r[2]) / 2,
|
||||
(r[3] - r[2]) / 2,
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -3,3 +3,5 @@ pub use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Vali
|
||||
pub mod curve_446;
|
||||
pub mod curve_api;
|
||||
pub mod proofs;
|
||||
|
||||
mod four_squares;
|
||||
|
||||
@@ -141,5 +141,6 @@ pub const HASH_METADATA_LEN_BYTES: usize = 256;
|
||||
pub mod binary;
|
||||
pub mod index;
|
||||
pub mod pke;
|
||||
pub mod pke_v2;
|
||||
pub mod range;
|
||||
pub mod rlwe;
|
||||
|
||||
@@ -101,7 +101,7 @@ impl<G: Curve> PublicCommit<G> {
|
||||
b,
|
||||
c1,
|
||||
c2,
|
||||
__marker: Default::default(),
|
||||
__marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
2194
tfhe-zk-pok/src/proofs/pke_v2.rs
Normal file
2194
tfhe-zk-pok/src/proofs/pke_v2.rs
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user