feat(zk): improve performance of zk pke proofs

This commit is contained in:
sarah el kazdadi
2024-06-17 10:14:32 +02:00
committed by sarah
parent dcd8224a7e
commit deebe09a8c
12 changed files with 853 additions and 114 deletions

View File

@@ -12,15 +12,15 @@ description = "tfhe-zk-pok: An implementation of zero-knowledge proofs of encryp
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
ark-bls12-381 = "0.4.0"
ark-ec = "0.4.2"
ark-ff = "0.4.2"
ark-poly = "0.4.2"
ark-bls12-381 = { package = "tfhe-ark-bls12-381", version = "0.4.0" }
ark-ec = { package = "tfhe-ark-ec", version = "0.4.2" }
ark-ff = { package = "tfhe-ark-ff", version = "0.4.3" }
ark-poly = { package = "tfhe-ark-poly", version = "0.4.2" }
ark-serialize = { version = "0.4.2" }
rand = "0.8.5"
rayon = "1.8.0"
sha3 = "0.10.8"
serde = { version = "~1.0", features = ["derive"] }
ark-serialize = { version = "0.4.2" }
zeroize = "1.7.0"
[dev-dependencies]

View File

@@ -466,8 +466,8 @@ pub mod g1 {
use ark_ec::bls12::Bls12Config;
use ark_ec::models::CurveConfig;
use ark_ec::short_weierstrass::{Affine, SWCurveConfig};
use ark_ec::{bls12, AffineRepr, Group};
use ark_ff::{Field, MontFp, One, PrimeField, Zero};
use ark_ec::{bls12, AdditiveGroup, AffineRepr, PrimeGroup};
use ark_ff::{MontFp, One, PrimeField, Zero};
use ark_serialize::{Compress, SerializationError};
use core::ops::Neg;
@@ -631,7 +631,7 @@ pub mod g2 {
use ark_ec::models::CurveConfig;
use ark_ec::short_weierstrass::{Affine, SWCurveConfig};
use ark_ec::{bls12, AffineRepr};
use ark_ff::{MontFp, Zero};
use ark_ff::MontFp;
use ark_serialize::{Compress, SerializationError};
pub type G2Affine = bls12::G2Affine<super::Config>;

View File

@@ -1,4 +1,4 @@
use ark_ec::{CurveGroup, Group, VariableBaseMSM};
use ark_ec::{AdditiveGroup as Group, CurveGroup, VariableBaseMSM};
use ark_ff::{BigInt, Field, MontFp, Zero};
use ark_poly::univariate::DensePolynomial;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
@@ -37,6 +37,8 @@ impl<T: fmt::Display + PartialEq + Field> fmt::Debug for MontIntDisplay<'_, T> {
}
}
pub mod msm;
pub mod bls12_381;
pub mod bls12_446;
@@ -44,6 +46,7 @@ pub trait FieldOps:
Copy
+ Send
+ Sync
+ core::fmt::Debug
+ core::ops::AddAssign<Self>
+ core::ops::SubAssign<Self>
+ core::ops::Add<Self, Output = Self>
@@ -62,6 +65,7 @@ pub trait FieldOps:
fn to_bytes(self) -> impl AsRef<[u8]>;
fn rand(rng: &mut dyn rand::RngCore) -> Self;
fn hash(values: &mut [Self], data: &[&[u8]]);
fn hash_128bit(values: &mut [Self], data: &[&[u8]]);
fn poly_mul(p: &[Self], q: &[Self]) -> Vec<Self>;
fn poly_sub(p: &[Self], q: &[Self]) -> Vec<Self> {
use core::iter::zip;
@@ -113,10 +117,22 @@ pub trait CurveGroupOps<Zp>:
const GENERATOR: Self;
const BYTE_SIZE: usize;
type Affine: Copy
+ Send
+ Sync
+ core::fmt::Debug
+ serde::Serialize
+ for<'de> serde::Deserialize<'de>
+ CanonicalSerialize
+ CanonicalDeserialize;
fn projective(affine: Self::Affine) -> Self;
fn mul_scalar(self, scalar: Zp) -> Self;
fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self;
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[Zp]) -> Self;
fn to_bytes(self) -> impl AsRef<[u8]>;
fn double(self) -> Self;
fn normalize(self) -> Self::Affine;
}
pub trait PairingGroupOps<Zp, G1, G2>:
@@ -164,6 +180,9 @@ impl FieldOps for bls12_381::Zp {
fn hash(values: &mut [Self], data: &[&[u8]]) {
Self::hash(values, data)
}
fn hash_128bit(values: &mut [Self], data: &[&[u8]]) {
Self::hash_128bit(values, data)
}
fn poly_mul(p: &[Self], q: &[Self]) -> Vec<Self> {
let p = p.iter().map(|x| x.inner).collect();
@@ -182,13 +201,20 @@ impl CurveGroupOps<bls12_381::Zp> for bls12_381::G1 {
const ZERO: Self = Self::ZERO;
const GENERATOR: Self = Self::GENERATOR;
const BYTE_SIZE: usize = Self::BYTE_SIZE;
type Affine = bls12_381::G1Affine;
fn projective(affine: Self::Affine) -> Self {
Self {
inner: affine.inner.into(),
}
}
fn mul_scalar(self, scalar: bls12_381::Zp) -> Self {
self.mul_scalar(scalar)
}
fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_381::Zp]) -> Self {
Self::multi_mul_scalar(bases, scalars)
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_381::Zp]) -> Self {
Self::Affine::multi_mul_scalar(bases, scalars)
}
fn to_bytes(self) -> impl AsRef<[u8]> {
@@ -198,19 +224,32 @@ impl CurveGroupOps<bls12_381::Zp> for bls12_381::G1 {
fn double(self) -> Self {
self.double()
}
fn normalize(self) -> Self::Affine {
Self::Affine {
inner: self.inner.into_affine(),
}
}
}
impl CurveGroupOps<bls12_381::Zp> for bls12_381::G2 {
const ZERO: Self = Self::ZERO;
const GENERATOR: Self = Self::GENERATOR;
const BYTE_SIZE: usize = Self::BYTE_SIZE;
type Affine = bls12_381::G2Affine;
fn projective(affine: Self::Affine) -> Self {
Self {
inner: affine.inner.into(),
}
}
fn mul_scalar(self, scalar: bls12_381::Zp) -> Self {
self.mul_scalar(scalar)
}
fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_381::Zp]) -> Self {
Self::multi_mul_scalar(bases, scalars)
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_381::Zp]) -> Self {
Self::Affine::multi_mul_scalar(bases, scalars)
}
fn to_bytes(self) -> impl AsRef<[u8]> {
@@ -220,6 +259,12 @@ impl CurveGroupOps<bls12_381::Zp> for bls12_381::G2 {
fn double(self) -> Self {
self.double()
}
fn normalize(self) -> Self::Affine {
Self::Affine {
inner: self.inner.into_affine(),
}
}
}
impl PairingGroupOps<bls12_381::Zp, bls12_381::G1, bls12_381::G2> for bls12_381::Gt {
@@ -254,6 +299,9 @@ impl FieldOps for bls12_446::Zp {
fn hash(values: &mut [Self], data: &[&[u8]]) {
Self::hash(values, data)
}
fn hash_128bit(values: &mut [Self], data: &[&[u8]]) {
Self::hash_128bit(values, data)
}
fn poly_mul(p: &[Self], q: &[Self]) -> Vec<Self> {
let p = p.iter().map(|x| x.inner).collect();
@@ -272,13 +320,21 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G1 {
const ZERO: Self = Self::ZERO;
const GENERATOR: Self = Self::GENERATOR;
const BYTE_SIZE: usize = Self::BYTE_SIZE;
type Affine = bls12_446::G1Affine;
fn projective(affine: Self::Affine) -> Self {
Self {
inner: affine.inner.into(),
}
}
fn mul_scalar(self, scalar: bls12_446::Zp) -> Self {
self.mul_scalar(scalar)
}
fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_446::Zp]) -> Self {
Self::multi_mul_scalar(bases, scalars)
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_446::Zp]) -> Self {
msm::msm_wnaf_g1_446(bases, scalars)
// Self::Affine::multi_mul_scalar(bases, scalars)
}
fn to_bytes(self) -> impl AsRef<[u8]> {
@@ -288,19 +344,32 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G1 {
fn double(self) -> Self {
self.double()
}
fn normalize(self) -> Self::Affine {
Self::Affine {
inner: self.inner.into_affine(),
}
}
}
impl CurveGroupOps<bls12_446::Zp> for bls12_446::G2 {
const ZERO: Self = Self::ZERO;
const GENERATOR: Self = Self::GENERATOR;
const BYTE_SIZE: usize = Self::BYTE_SIZE;
type Affine = bls12_446::G2Affine;
fn projective(affine: Self::Affine) -> Self {
Self {
inner: affine.inner.into(),
}
}
fn mul_scalar(self, scalar: bls12_446::Zp) -> Self {
self.mul_scalar(scalar)
}
fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_446::Zp]) -> Self {
Self::multi_mul_scalar(bases, scalars)
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_446::Zp]) -> Self {
Self::Affine::multi_mul_scalar(bases, scalars)
}
fn to_bytes(self) -> impl AsRef<[u8]> {
@@ -310,6 +379,12 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G2 {
fn double(self) -> Self {
self.double()
}
fn normalize(self) -> Self::Affine {
Self::Affine {
inner: self.inner.into_affine(),
}
}
}
impl PairingGroupOps<bls12_446::Zp, bls12_446::G1, bls12_446::G2> for bls12_446::Gt {

View File

@@ -36,6 +36,39 @@ fn bigint_to_bytes(x: [u64; 6]) -> [u8; 6 * 8] {
mod g1 {
use super::*;
#[derive(
Copy,
Clone,
Debug,
PartialEq,
Eq,
Serialize,
Deserialize,
Hash,
CanonicalSerialize,
CanonicalDeserialize,
)]
#[repr(transparent)]
pub struct G1Affine {
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
pub(crate) inner: ark_bls12_381::g1::G1Affine,
}
impl G1Affine {
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G1 {
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
G1 {
inner: ark_bls12_381::g1::G1Projective::msm(
unsafe {
&*(bases as *const [G1Affine] as *const [ark_bls12_381::g1::G1Affine])
},
unsafe { &*(scalars as *const [Zp] as *const [ark_bls12_381::Fr]) },
)
.unwrap(),
}
}
}
#[derive(
Copy,
Clone,
@@ -179,6 +212,39 @@ mod g1 {
mod g2 {
use super::*;
#[derive(
Copy,
Clone,
Debug,
PartialEq,
Eq,
Serialize,
Deserialize,
Hash,
CanonicalSerialize,
CanonicalDeserialize,
)]
#[repr(transparent)]
pub struct G2Affine {
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
pub(crate) inner: ark_bls12_381::g2::G2Affine,
}
impl G2Affine {
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G2 {
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
G2 {
inner: ark_bls12_381::g2::G2Projective::msm(
unsafe {
&*(bases as *const [G2Affine] as *const [ark_bls12_381::g2::G2Affine])
},
unsafe { &*(scalars as *const [Zp] as *const [ark_bls12_381::Fr]) },
)
.unwrap(),
}
}
}
#[derive(
Copy,
Clone,
@@ -193,7 +259,7 @@ mod g2 {
#[repr(transparent)]
pub struct G2 {
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
pub(super) inner: ark_bls12_381::G2Projective,
pub(crate) inner: ark_bls12_381::G2Projective,
}
impl fmt::Debug for G2 {
@@ -640,6 +706,25 @@ mod zp {
*value = Zp::from_raw_u64x6(bytes.map(u64::from_le_bytes));
}
}
pub fn hash_128bit(values: &mut [Zp], data: &[&[u8]]) {
use sha3::digest::{ExtendableOutput, Update, XofReader};
let mut hasher = sha3::Shake256::default();
for data in data {
hasher.update(data);
}
let mut reader = hasher.finalize_xof();
for value in values {
let mut bytes = [0u8; 2 * 8];
reader.read(&mut bytes);
let limbs: [u64; 2] = unsafe { core::mem::transmute(bytes) };
*value = Zp {
inner: BigInt([limbs[0], limbs[1], 0, 0]).into(),
};
}
}
}
impl Add for Zp {
@@ -714,8 +799,8 @@ mod zp {
}
}
pub use g1::G1;
pub use g2::G2;
pub use g1::{G1Affine, G1};
pub use g2::{G2Affine, G2};
pub use gt::Gt;
pub use zp::Zp;

View File

@@ -36,6 +36,39 @@ fn bigint_to_bytes(x: [u64; 7]) -> [u8; 7 * 8] {
mod g1 {
use super::*;
#[derive(
Copy,
Clone,
Debug,
PartialEq,
Eq,
Serialize,
Deserialize,
Hash,
CanonicalSerialize,
CanonicalDeserialize,
)]
#[repr(transparent)]
pub struct G1Affine {
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
pub(crate) inner: crate::curve_446::g1::G1Affine,
}
impl G1Affine {
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G1 {
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
G1 {
inner: crate::curve_446::g1::G1Projective::msm(
unsafe {
&*(bases as *const [G1Affine] as *const [crate::curve_446::g1::G1Affine])
},
unsafe { &*(scalars as *const [Zp] as *const [crate::curve_446::Fr]) },
)
.unwrap(),
}
}
}
#[derive(
Copy,
Clone,
@@ -93,17 +126,17 @@ mod g1 {
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self {
use rayon::prelude::*;
let n_threads = rayon::current_num_threads();
let chunk_size = bases.len().div_ceil(n_threads);
bases
let bases = bases
.par_iter()
.map(|&x| x.inner.into_affine())
.chunks(chunk_size)
.zip(scalars.par_iter().map(|&x| x.inner).chunks(chunk_size))
.map(|(bases, scalars)| Self {
inner: crate::curve_446::g1::G1Projective::msm(&bases, &scalars).unwrap(),
.collect::<Vec<_>>();
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
Self {
inner: crate::curve_446::g1::G1Projective::msm(&bases, unsafe {
&*(scalars as *const [Zp] as *const [crate::curve_446::Fr])
})
.sum::<Self>()
.unwrap(),
}
}
pub fn to_bytes(self) -> [u8; Self::BYTE_SIZE] {
@@ -178,6 +211,39 @@ mod g1 {
mod g2 {
use super::*;
#[derive(
Copy,
Clone,
Debug,
PartialEq,
Eq,
Serialize,
Deserialize,
Hash,
CanonicalSerialize,
CanonicalDeserialize,
)]
#[repr(transparent)]
pub struct G2Affine {
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
pub(crate) inner: crate::curve_446::g2::G2Affine,
}
impl G2Affine {
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G2 {
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
G2 {
inner: crate::curve_446::g2::G2Projective::msm(
unsafe {
&*(bases as *const [G2Affine] as *const [crate::curve_446::g2::G2Affine])
},
unsafe { &*(scalars as *const [Zp] as *const [crate::curve_446::Fr]) },
)
.unwrap(),
}
}
}
#[derive(
Copy,
Clone,
@@ -192,7 +258,7 @@ mod g2 {
#[repr(transparent)]
pub struct G2 {
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
pub(super) inner: crate::curve_446::g2::G2Projective,
pub(crate) inner: crate::curve_446::g2::G2Projective,
}
impl fmt::Debug for G2 {
@@ -672,7 +738,7 @@ mod zp {
#[repr(transparent)]
pub struct Zp {
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
pub(crate) inner: crate::curve_446::Fr,
pub inner: crate::curve_446::Fr,
}
impl fmt::Debug for Zp {
@@ -772,6 +838,25 @@ mod zp {
*value = Zp::from_raw_u64x7(bytes.map(u64::from_le_bytes));
}
}
pub fn hash_128bit(values: &mut [Zp], data: &[&[u8]]) {
use sha3::digest::{ExtendableOutput, Update, XofReader};
let mut hasher = sha3::Shake256::default();
for data in data {
hasher.update(data);
}
let mut reader = hasher.finalize_xof();
for value in values {
let mut bytes = [0u8; 2 * 8];
reader.read(&mut bytes);
let limbs: [u64; 2] = unsafe { core::mem::transmute(bytes) };
*value = Zp {
inner: BigInt([limbs[0], limbs[1], 0, 0, 0]).into(),
};
}
}
}
impl Add for Zp {
@@ -846,8 +931,8 @@ mod zp {
}
}
pub use g1::G1;
pub use g2::G2;
pub use g1::{G1Affine, G1};
pub use g2::{G2Affine, G2};
pub use gt::Gt;
pub use zp::Zp;

View File

@@ -0,0 +1,442 @@
use ark_ec::short_weierstrass::Affine;
use ark_ec::AffineRepr;
use ark_ff::{AdditiveGroup, BigInt, BigInteger, Field, Fp, PrimeField};
use rayon::prelude::*;
fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator<Item = i64> + '_ {
let scalar = a.as_ref();
let radix: u64 = 1 << w;
let window_mask: u64 = radix - 1;
let mut carry = 0u64;
let num_bits = if num_bits == 0 {
a.num_bits() as usize
} else {
num_bits
};
let digits_count = (num_bits + w - 1) / w;
(0..digits_count).map(move |i| {
// Construct a buffer of bits of the scalar, starting at `bit_offset`.
let bit_offset = i * w;
let u64_idx = bit_offset / 64;
let bit_idx = bit_offset % 64;
// Read the bits from the scalar
let bit_buf = if bit_idx < 64 - w || u64_idx == scalar.len() - 1 {
// This window's bits are contained in a single u64,
// or it's the last u64 anyway.
scalar[u64_idx] >> bit_idx
} else {
// Combine the current u64's bits with the bits from the next u64
(scalar[u64_idx] >> bit_idx) | (scalar[1 + u64_idx] << (64 - bit_idx))
};
// Read the actual coefficient value from the window
let coef = carry + (bit_buf & window_mask); // coef = [0, 2^r)
// Recenter coefficients from [0,2^w) to [-2^w/2, 2^w/2)
carry = (coef + radix / 2) >> w;
let mut digit = (coef as i64) - (carry << w) as i64;
if i == digits_count - 1 {
digit += (carry << w) as i64;
}
digit
})
}
// Compute msm using windowed non-adjacent form
pub fn msm_wnaf_g1_446(
bases: &[super::bls12_446::G1Affine],
scalars: &[super::bls12_446::Zp],
) -> super::bls12_446::G1 {
use super::bls12_446::*;
let num_bits = 299usize;
type BaseField = Fp<ark_ff::MontBackend<crate::curve_446::FqConfig, 7>, 7>;
assert_eq!(bases.len(), scalars.len());
let size = bases.len();
let scalars = &*scalars
.par_iter()
.map(|x| x.inner.into_bigint())
.collect::<Vec<_>>();
let c = if size < 32 {
3
} else {
// natural log approx
(size.ilog2() as usize * 69 / 100) + 2
};
let digits_count = (num_bits + c - 1) / c;
let scalar_digits = scalars
.into_par_iter()
.flat_map_iter(|s| make_digits(s, c, num_bits))
.collect::<Vec<_>>();
let zero = G1Affine {
inner: Affine::zero(),
};
let window_sums: Vec<_> = (0..digits_count)
.into_par_iter()
.map(|i| {
let n = 1 << c;
let mut indices = vec![vec![]; n];
let mut d = vec![BaseField::ZERO; n + 1];
let mut e = vec![BaseField::ZERO; n + 1];
for (idx, digits) in scalar_digits.chunks(digits_count).enumerate() {
use core::cmp::Ordering;
// digits is the digits thing of the first scalar?
let scalar = digits[i];
match 0.cmp(&scalar) {
Ordering::Less => indices[(scalar - 1) as usize].push(idx),
Ordering::Greater => indices[(-scalar - 1) as usize].push(!idx),
Ordering::Equal => (),
}
}
let mut buckets = vec![zero; 1 << c];
loop {
d[0] = BaseField::ONE;
for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices).enumerate() {
if let Some(idx) = idx.last().copied() {
let value = if idx >> (usize::BITS - 1) == 1 {
let mut val = bases[!idx];
val.inner.y = -val.inner.y;
val
} else {
bases[idx]
};
if !bucket.inner.infinity {
let a = value.inner.x - bucket.inner.x;
if a != BaseField::ZERO {
d[k + 1] = d[k] * a;
} else if value.inner.y == bucket.inner.y {
d[k + 1] = d[k] * value.inner.y.double();
} else {
d[k + 1] = d[k];
}
continue;
}
}
d[k + 1] = d[k];
}
e[n] = d[n].inverse().unwrap();
for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices)
.enumerate()
.rev()
{
if let Some(idx) = idx.last().copied() {
let value = if idx >> (usize::BITS - 1) == 1 {
let mut val = bases[!idx];
val.inner.y = -val.inner.y;
val
} else {
bases[idx]
};
if !bucket.inner.infinity {
let a = value.inner.x - bucket.inner.x;
if a != BaseField::ZERO {
e[k] = e[k + 1] * a;
} else if value.inner.y == bucket.inner.y {
e[k] = e[k + 1] * value.inner.y.double();
} else {
e[k] = e[k + 1];
}
continue;
}
}
e[k] = e[k + 1];
}
let d = &d[..n];
let e = &e[1..];
let mut empty = true;
for ((&d, &e), (bucket, idx)) in core::iter::zip(
core::iter::zip(d, e),
core::iter::zip(&mut buckets, &mut indices),
) {
empty &= idx.len() <= 1;
if let Some(idx) = idx.pop() {
let value = if idx >> (usize::BITS - 1) == 1 {
let mut val = bases[!idx];
val.inner.y = -val.inner.y;
val
} else {
bases[idx]
};
if !bucket.inner.infinity {
let x1 = bucket.inner.x;
let x2 = value.inner.x;
let y1 = bucket.inner.y;
let y2 = value.inner.y;
let eq_x = x1 == x2;
if eq_x && y1 != y2 {
bucket.inner.infinity = true;
} else {
let r = d * e;
let m = if eq_x {
let x1 = x1.square();
x1 + x1.double()
} else {
y2 - y1
};
let m = m * r;
let x3 = m.square() - x1 - x2;
let y3 = m * (x1 - x3) - y1;
bucket.inner.x = x3;
bucket.inner.y = y3;
}
} else {
*bucket = value;
}
}
}
if empty {
break;
}
}
let mut running_sum = G1::ZERO;
let mut res = G1::ZERO;
buckets.into_iter().rev().for_each(|b| {
running_sum.inner += b.inner;
res += running_sum;
});
res
})
.collect();
// We store the sum for the lowest window.
let lowest = *window_sums.first().unwrap();
// We're traversing windows from high to low.
lowest
+ window_sums[1..]
.iter()
.rev()
.fold(G1::ZERO, |mut total, &sum_i| {
total += sum_i;
for _ in 0..c {
total = total.double();
}
total
})
}
// Compute msm using windowed non-adjacent form
pub fn msm_wnaf_g1_446_extended(
bases: &[super::bls12_446::G1Affine],
scalars: &[super::bls12_446::Zp],
) -> super::bls12_446::G1 {
use super::bls12_446::*;
type BaseField = Fp<ark_ff::MontBackend<crate::curve_446::FqConfig, 7>, 7>;
// let num_bits = 75usize;
// let mask = BigInt([!0, (1 << 11) - 1, 0, 0, 0]);
// let scalars = &*scalars
// .par_iter()
// .map(|x| x.inner.into_bigint())
// .flat_map_iter(|x| (0..4).map(move |i| (x >> (75 * i)) & mask))
// .collect::<Vec<_>>();
let num_bits = 150usize;
let mask = BigInt([!0, !0, (1 << 22) - 1, 0, 0]);
let scalars = &*scalars
.par_iter()
.map(|x| x.inner.into_bigint())
.flat_map_iter(|x| (0..2).map(move |i| (x >> (150 * i)) & mask))
.collect::<Vec<_>>();
assert_eq!(bases.len(), scalars.len());
let size = bases.len();
let c = if size < 32 {
3
} else {
// natural log approx
(size.ilog2() as usize * 69 / 100) + 2
};
let c = c - 3;
let digits_count = (num_bits + c - 1) / c;
let scalar_digits = scalars
.into_par_iter()
.flat_map_iter(|s| make_digits(s, c, num_bits))
.collect::<Vec<_>>();
let zero = G1Affine {
inner: Affine::zero(),
};
let window_sums: Vec<_> = (0..digits_count)
.into_par_iter()
.map(|i| {
let n = 1 << c;
let mut indices = vec![vec![]; n];
let mut d = vec![BaseField::ZERO; n + 1];
let mut e = vec![BaseField::ZERO; n + 1];
for (idx, digits) in scalar_digits.chunks(digits_count).enumerate() {
use core::cmp::Ordering;
// digits is the digits thing of the first scalar?
let scalar = digits[i];
match 0.cmp(&scalar) {
Ordering::Less => indices[(scalar - 1) as usize].push(idx),
Ordering::Greater => indices[(-scalar - 1) as usize].push(!idx),
Ordering::Equal => (),
}
}
let mut buckets = vec![zero; 1 << c];
loop {
d[0] = BaseField::ONE;
for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices).enumerate() {
if let Some(idx) = idx.last().copied() {
let value = if idx >> (usize::BITS - 1) == 1 {
let mut val = bases[!idx];
val.inner.y = -val.inner.y;
val
} else {
bases[idx]
};
if !bucket.inner.infinity {
let a = value.inner.x - bucket.inner.x;
if a != BaseField::ZERO {
d[k + 1] = d[k] * a;
} else if value.inner.y == bucket.inner.y {
d[k + 1] = d[k] * value.inner.y.double();
} else {
d[k + 1] = d[k];
}
continue;
}
}
d[k + 1] = d[k];
}
e[n] = d[n].inverse().unwrap();
for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices)
.enumerate()
.rev()
{
if let Some(idx) = idx.last().copied() {
let value = if idx >> (usize::BITS - 1) == 1 {
let mut val = bases[!idx];
val.inner.y = -val.inner.y;
val
} else {
bases[idx]
};
if !bucket.inner.infinity {
let a = value.inner.x - bucket.inner.x;
if a != BaseField::ZERO {
e[k] = e[k + 1] * a;
} else if value.inner.y == bucket.inner.y {
e[k] = e[k + 1] * value.inner.y.double();
} else {
e[k] = e[k + 1];
}
continue;
}
}
e[k] = e[k + 1];
}
let d = &d[..n];
let e = &e[1..];
let mut empty = true;
for ((&d, &e), (bucket, idx)) in core::iter::zip(
core::iter::zip(d, e),
core::iter::zip(&mut buckets, &mut indices),
) {
empty &= idx.len() <= 1;
if let Some(idx) = idx.pop() {
let value = if idx >> (usize::BITS - 1) == 1 {
let mut val = bases[!idx];
val.inner.y = -val.inner.y;
val
} else {
bases[idx]
};
if !bucket.inner.infinity {
let x1 = bucket.inner.x;
let x2 = value.inner.x;
let y1 = bucket.inner.y;
let y2 = value.inner.y;
let eq_x = x1 == x2;
if eq_x && y1 != y2 {
bucket.inner.infinity = true;
} else {
let r = d * e;
let m = if eq_x {
let x1 = x1.square();
x1 + x1.double()
} else {
y2 - y1
};
let m = m * r;
let x3 = m.square() - x1 - x2;
let y3 = m * (x1 - x3) - y1;
bucket.inner.x = x3;
bucket.inner.y = y3;
}
} else {
*bucket = value;
}
}
}
if empty {
break;
}
}
let mut running_sum = G1::ZERO;
let mut res = G1::ZERO;
buckets.into_iter().rev().for_each(|b| {
running_sum.inner += b.inner;
res += running_sum;
});
res
})
.collect();
// We store the sum for the lowest window.
let lowest = *window_sums.first().unwrap();
// We're traversing windows from high to low.
lowest
+ window_sums[1..]
.iter()
.rev()
.fold(G1::ZERO, |mut total, &sum_i| {
total += sum_i;
for _ in 0..c {
total = total.double();
}
total
})
}

View File

@@ -6,7 +6,10 @@ pub struct PublicParams<G: Curve> {
}
impl<G: Curve> PublicParams<G> {
pub fn from_vec(g_list: Vec<G::G1>, g_hat_list: Vec<G::G2>) -> Self {
pub fn from_vec(
g_list: Vec<Affine<G::Zp, G::G1>>,
g_hat_list: Vec<Affine<G::Zp, G::G2>>,
) -> Self {
Self {
g_lists: GroupElements::from_vec(g_list, g_hat_list),
}
@@ -57,7 +60,7 @@ pub fn commit<G: Curve>(
let mut c_hat = g_hat.mul_scalar(gamma);
for j in 1..n + 1 {
let term = if x[j] != 0 {
public.g_lists.g_hat_list[j]
G::G2::projective(public.g_lists.g_hat_list[j])
} else {
G::G2::ZERO
};
@@ -91,7 +94,7 @@ pub fn prove<G: Curve>(
let mut c_y = g.mul_scalar(gamma_y);
for j in 1..n + 1 {
c_y += (g_list[n + 1 - j]).mul_scalar(y[j] * G::Zp::from_u64(x[j]));
c_y += (G::G1::projective(g_list[n + 1 - j])).mul_scalar(y[j] * G::Zp::from_u64(x[j]));
}
let y_bytes = &*(1..n + 1)
@@ -143,7 +146,7 @@ pub fn prove<G: Curve>(
let mut proof = g.mul_scalar(poly[0]);
for i in 1..poly.len() {
proof += g_list[i].mul_scalar(poly[i]);
proof += G::G1::projective(g_list[i]).mul_scalar(poly[i]);
}
proof
};
@@ -190,7 +193,7 @@ pub fn verify<G: Curve>(
let numerator = {
let mut p = c_y.mul_scalar(delta_y);
for i in 1..n + 1 {
let gy = g_list[n + 1 - i].mul_scalar(y[i]);
let gy = G::G1::projective(g_list[n + 1 - i]).mul_scalar(y[i]);
p += gy.mul_scalar(delta_eq).mul_scalar(t[i]) - gy.mul_scalar(delta_y);
}
e(p, c_hat)
@@ -198,7 +201,9 @@ pub fn verify<G: Curve>(
let denominator = {
let mut q = G::G2::ZERO;
for i in 1..n + 1 {
q += g_hat_list[i].mul_scalar(delta_eq).mul_scalar(t[i]);
q += G::G2::projective(g_hat_list[i])
.mul_scalar(delta_eq)
.mul_scalar(t[i]);
}
e(c_y, q)
};

View File

@@ -6,7 +6,10 @@ pub struct PublicParams<G: Curve> {
}
impl<G: Curve> PublicParams<G> {
pub fn from_vec(g_list: Vec<G::G1>, g_hat_list: Vec<G::G2>) -> Self {
pub fn from_vec(
g_list: Vec<Affine<G::Zp, G::G1>>,
g_hat_list: Vec<Affine<G::Zp, G::G2>>,
) -> Self {
Self {
g_lists: GroupElements::from_vec(g_list, g_hat_list),
}
@@ -55,7 +58,7 @@ pub fn commit<G: Curve>(
let mut c = g.mul_scalar(gamma);
for j in 1..n + 1 {
let term = public.g_lists.g_list[j].mul_scalar(G::Zp::from_u64(m[j]));
let term = G::G1::projective(public.g_lists.g_list[j]).mul_scalar(G::Zp::from_u64(m[j]));
c += term;
}
@@ -80,11 +83,11 @@ pub fn prove<G: Curve>(
let gamma = private.gamma;
let g_list = &public.0.g_lists.g_list;
let mut pi = g_list[n + 1 - i].mul_scalar(gamma);
let mut pi = G::G1::projective(g_list[n + 1 - i]).mul_scalar(gamma);
for j in 1..n + 1 {
if i != j {
let term = if m[j] & 1 == 1 {
g_list[n + 1 - i + j]
G::G1::projective(g_list[n + 1 - i + j])
} else {
G::G1::ZERO
};
@@ -110,8 +113,13 @@ pub fn verify<G: Curve>(
let n = public.0.g_lists.message_len;
let i = index + 1;
let lhs = e(c, g_hat_list[n + 1 - i]);
let rhs = e(proof.pi, g_hat) + (e(g_list[1], g_hat_list[n])).mul_scalar(G::Zp::from_u64(mi));
let lhs = e(c, G::G2::projective(g_hat_list[n + 1 - i]));
let rhs = e(proof.pi, g_hat)
+ (e(
G::G1::projective(g_list[1]),
G::G2::projective(g_hat_list[n]),
))
.mul_scalar(G::Zp::from_u64(mi));
if lhs == rhs {
Ok(())

View File

@@ -75,6 +75,8 @@ impl<T: ?Sized + IndexMut<usize>> IndexMut<usize> for OneBased<T> {
}
}
pub type Affine<Zp, Group> = <Group as CurveGroupOps<Zp>>::Affine;
#[derive(
Clone, Debug, serde::Serialize, serde::Deserialize, CanonicalSerialize, CanonicalDeserialize,
)]
@@ -83,8 +85,8 @@ impl<T: ?Sized + IndexMut<usize>> IndexMut<usize> for OneBased<T> {
serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize"
))]
struct GroupElements<G: Curve> {
g_list: OneBased<Vec<G::G1>>,
g_hat_list: OneBased<Vec<G::G2>>,
g_list: OneBased<Vec<Affine<G::Zp, G::G1>>>,
g_hat_list: OneBased<Vec<Affine<G::Zp, G::G2>>>,
message_len: usize,
}
@@ -98,9 +100,9 @@ impl<G: Curve> GroupElements<G> {
for i in 0..2 * message_len {
if i == message_len {
g_list.push(G::G1::ZERO);
g_list.push(G::G1::ZERO.normalize());
} else {
g_list.push(g_cur);
g_list.push(g_cur.normalize());
}
g_cur = g_cur.mul_scalar(alpha);
}
@@ -111,22 +113,22 @@ impl<G: Curve> GroupElements<G> {
let mut g_hat_list = Vec::with_capacity(message_len);
let mut g_hat_cur = G::G2::GENERATOR.mul_scalar(alpha);
for _ in 0..message_len {
g_hat_list.push(g_hat_cur);
g_hat_cur = (g_hat_cur).mul_scalar(alpha);
g_hat_list.push(g_hat_cur.normalize());
g_hat_cur = g_hat_cur.mul_scalar(alpha);
}
g_hat_list
},
);
Self {
g_list: OneBased::new(g_list),
g_hat_list: OneBased::new(g_hat_list),
message_len,
}
Self::from_vec(g_list, g_hat_list)
}
pub fn from_vec(g_list: Vec<G::G1>, g_hat_list: Vec<G::G2>) -> Self {
pub fn from_vec(
g_list: Vec<Affine<G::Zp, G::G1>>,
g_hat_list: Vec<Affine<G::Zp, G::G2>>,
) -> Self {
let message_len = g_hat_list.len();
Self {
g_list: OneBased::new(g_list),
g_hat_list: OneBased::new(g_hat_list),

View File

@@ -1,3 +1,6 @@
// TODO: refactor copy-pasted code in proof/verify
// TODO: ask about metadata in hashing functions
use super::*;
use core::marker::PhantomData;
use rayon::prelude::*;
@@ -22,8 +25,8 @@ pub struct PublicParams<G: Curve> {
impl<G: Curve> PublicParams<G> {
#[allow(clippy::too_many_arguments)]
pub fn from_vec(
g_list: Vec<G::G1>,
g_hat_list: Vec<G::G2>,
g_list: Vec<Affine<G::Zp, G::G1>>,
g_hat_list: Vec<Affine<G::Zp, G::G2>>,
big_d: usize,
n: usize,
d: usize,
@@ -246,12 +249,12 @@ pub fn prove<G: Curve>(
let mut dot = 0i128;
for j in 0..d {
let b = if i + j < d {
b[d - j - i - 1]
b[d - j - i - 1] as i128
} else {
b[2 * d - j - i - 1].wrapping_neg()
-(b[2 * d - j - i - 1] as i128)
};
dot += r[d - j - 1] as i128 * b as i128;
dot += r[d - j - 1] as i128 * b;
}
*r2 += dot;
@@ -293,8 +296,9 @@ pub fn prove<G: Curve>(
let mut c_hat = g_hat.mul_scalar(gamma);
for j in 1..big_d + 1 {
let term = if w[j] { g_hat_list[j] } else { G::G2::ZERO };
c_hat += term;
if w[j] {
c_hat += G::G2::projective(g_hat_list[j]);
}
}
let x_bytes = &*[
@@ -337,7 +341,7 @@ pub fn prove<G: Curve>(
compute_a_theta::<G>(theta0, d, a, k, b, &mut a_theta, t, delta, b_i, b_r, q);
let mut t = vec![G::Zp::ZERO; n];
G::Zp::hash(
G::Zp::hash_128bit(
&mut t,
&[
&(1..n + 1)
@@ -356,6 +360,7 @@ pub fn prove<G: Curve>(
&[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()],
);
let [delta_eq, delta_y] = delta;
let delta = [delta_eq, delta_y, delta_theta];
let mut poly_0 = vec![G::Zp::ZERO; n + 1];
let mut poly_1 = vec![G::Zp::ZERO; big_d + 1];
@@ -394,10 +399,11 @@ pub fn prove<G: Curve>(
t_theta += theta0[d + i] * G::Zp::from_i64(c2[i]);
}
let mut poly = G::Zp::poly_sub(
&G::Zp::poly_mul(&poly_0, &poly_1),
&G::Zp::poly_mul(&poly_2, &poly_3),
let mul = rayon::join(
|| G::Zp::poly_mul(&poly_0, &poly_1),
|| G::Zp::poly_mul(&poly_2, &poly_3),
);
let mut poly = G::Zp::poly_sub(&mul.0, &mul.1);
if poly.len() > n + 1 {
poly[n + 1] -= t_theta * delta_theta;
}
@@ -418,6 +424,7 @@ pub fn prove<G: Curve>(
}
})
.collect::<Vec<_>>();
let c_h = G::G1::multi_mul_scalar(&g_list.0[..n], &scalars);
let mut z = G::Zp::ZERO;
@@ -497,6 +504,7 @@ pub fn prove<G: Curve>(
}
let mut q = vec![G::Zp::ZERO; n];
// https://en.wikipedia.org/wiki/Polynomial_long_division#Pseudocode
for i in (0..n).rev() {
poly[i] = poly[i] + z * poly[i + 1];
q[i] = poly[i + 1];
@@ -570,6 +578,10 @@ fn compute_a_theta<G: Curve>(
let theta1 = &theta0[..d];
let theta2 = &theta0[d..];
let a = a.iter().map(|x| G::Zp::from_i64(*x)).collect::<Vec<_>>();
let b = b.iter().map(|x| G::Zp::from_i64(*x)).collect::<Vec<_>>();
{
let a_theta = &mut a_theta[..d];
a_theta
@@ -579,27 +591,24 @@ fn compute_a_theta<G: Curve>(
let mut dot = G::Zp::ZERO;
for j in 0..d {
let a = if i <= j {
a[j - i]
if i <= j {
dot += a[j - i] * theta1[j];
} else {
a[d + j - i].wrapping_neg()
};
dot += G::Zp::from_i64(a) * theta1[j];
dot -= a[(d + j) - i] * theta1[j];
}
}
for j in 0..k {
let b = if i + j < d {
b[d - i - j - 1]
if i + j < d {
dot += b[d - i - j - 1] * theta2[j];
} else {
b[2 * d - i - j - 1].wrapping_neg()
dot -= b[2 * d - i - j - 1] * theta2[j];
};
dot += G::Zp::from_i64(b) * theta2[j];
}
*a_theta_i = dot;
});
}
let a_theta = &mut a_theta[d..];
let step = t.ilog2() as usize;
@@ -726,7 +735,7 @@ pub fn verify<G: Curve>(
}
let mut t = vec![G::Zp::ZERO; n];
G::Zp::hash(
G::Zp::hash_128bit(
&mut t,
&[
&(1..n + 1)
@@ -745,6 +754,7 @@ pub fn verify<G: Curve>(
&[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()],
);
let [delta_eq, delta_y] = delta;
let delta = [delta_eq, delta_y, delta_theta];
if let (Some(pi_kzg), Some(c_hat_t), Some(c_h)) = (pi_kzg, c_hat_t, c_h) {
let mut z = G::Zp::ZERO;
@@ -789,7 +799,11 @@ pub fn verify<G: Curve>(
if e(pi, G::G2::GENERATOR)
!= e(c_y.mul_scalar(delta_y) + c_h, c_hat)
- e(c_y.mul_scalar(delta_eq), c_hat_t)
- e(g_list[1], g_hat_list[n]).mul_scalar(t_theta * delta_theta)
- e(
G::G1::projective(g_list[1]),
G::G2::projective(g_hat_list[n]),
)
.mul_scalar(t_theta * delta_theta)
{
return Err(());
}
@@ -822,13 +836,17 @@ pub fn verify<G: Curve>(
if e(c_h - G::G1::GENERATOR.mul_scalar(p_h), G::G2::GENERATOR)
+ e(G::G1::GENERATOR, c_hat_t - G::G2::GENERATOR.mul_scalar(p_t)).mul_scalar(w)
== e(pi_kzg, g_hat_list[1] - G::G2::GENERATOR.mul_scalar(z))
== e(
pi_kzg,
G::G2::projective(g_hat_list[1]) - G::G2::GENERATOR.mul_scalar(z),
)
{
Ok(())
} else {
Err(())
}
} else {
// PERF: rewrite as multi_mul_scalar?
let (term0, term1) = rayon::join(
|| {
let p = c_y.mul_scalar(delta_y)
@@ -839,7 +857,7 @@ pub fn verify<G: Curve>(
if i < big_d + 1 {
factor += delta_theta * a_theta[i - 1];
}
g_list[n + 1 - i].mul_scalar(factor)
G::G1::projective(g_list[n + 1 - i]).mul_scalar(factor)
})
.sum::<G::G1>();
let q = c_hat;
@@ -849,14 +867,14 @@ pub fn verify<G: Curve>(
let p = c_y;
let q = (1..n + 1)
.into_par_iter()
.map(|i| g_hat_list[i].mul_scalar(delta_eq * t[i]))
.map(|i| G::G2::projective(g_hat_list[i]).mul_scalar(delta_eq * t[i]))
.sum::<G::G2>();
e(p, q)
},
);
let term2 = {
let p = g_list[1];
let q = g_hat_list[n];
let p = G::G1::projective(g_list[1]);
let q = G::G2::projective(g_hat_list[n]);
e(p, q)
};

View File

@@ -6,7 +6,10 @@ pub struct PublicParams<G: Curve> {
}
impl<G: Curve> PublicParams<G> {
pub fn from_vec(g_list: Vec<G::G1>, g_hat_list: Vec<G::G2>) -> Self {
pub fn from_vec(
g_list: Vec<Affine<G::Zp, G::G1>>,
g_hat_list: Vec<Affine<G::Zp, G::G2>>,
) -> Self {
Self {
g_lists: GroupElements::from_vec(g_list, g_hat_list),
}
@@ -54,7 +57,8 @@ pub fn commit<G: Curve>(
let g_hat = G::G2::GENERATOR;
let r = G::Zp::rand(rng);
let v_hat = g_hat.mul_scalar(r) + public.g_lists.g_hat_list[1].mul_scalar(G::Zp::from_u64(x));
let v_hat = g_hat.mul_scalar(r)
+ G::G2::projective(public.g_lists.g_hat_list[1]).mul_scalar(G::Zp::from_u64(x));
(PublicCommit { l, v_hat }, PrivateCommit { x, r })
}
@@ -87,7 +91,7 @@ pub fn prove<G: Curve>(
let mut c = g_hat.mul_scalar(gamma);
for j in 1..l + 1 {
let term = if x_bits[j] != 0 {
g_hat_list[j]
G::G2::projective(g_hat_list[j])
} else {
G::G2::ZERO
};
@@ -96,13 +100,13 @@ pub fn prove<G: Curve>(
c
};
let mut proof_x = -g_list[n].mul_scalar(r);
let mut proof_x = -G::G1::projective(g_list[n]).mul_scalar(r);
for i in 1..l + 1 {
let mut term = g_list[n + 1 - i].mul_scalar(gamma);
let mut term = G::G1::projective(g_list[n + 1 - i]).mul_scalar(gamma);
for j in 1..l + 1 {
if j != i {
let term_inner = if x_bits[j] != 0 {
g_list[n + 1 - i + j]
G::G1::projective(g_list[n + 1 - i + j])
} else {
G::G1::ZERO
};
@@ -124,7 +128,7 @@ pub fn prove<G: Curve>(
let y = OneBased(y);
let mut c_y = g.mul_scalar(gamma_y);
for j in 1..l + 1 {
c_y += g_list[n + 1 - j].mul_scalar(y[j] * G::Zp::from_u64(x_bits[j]));
c_y += G::G1::projective(g_list[n + 1 - j]).mul_scalar(y[j] * G::Zp::from_u64(x_bits[j]));
}
let y_bytes = &*(1..n + 1)
@@ -145,11 +149,11 @@ pub fn prove<G: Curve>(
let mut proof_eq = G::G1::ZERO;
for i in 1..n + 1 {
let mut numerator = g_list[n + 1 - i].mul_scalar(gamma);
let mut numerator = G::G1::projective(g_list[n + 1 - i]).mul_scalar(gamma);
for j in 1..n + 1 {
if j != i {
let term = if x_bits[j] != 0 {
g_list[n + 1 - i + j]
G::G1::projective(g_list[n + 1 - i + j])
} else {
G::G1::ZERO
};
@@ -158,10 +162,11 @@ pub fn prove<G: Curve>(
}
numerator = numerator.mul_scalar(t[i] * y[i]);
let mut denominator = g_list[i].mul_scalar(gamma_y);
let mut denominator = G::G1::projective(g_list[i]).mul_scalar(gamma_y);
for j in 1..n + 1 {
if j != i {
denominator += g_list[n + 1 - j + i].mul_scalar(y[j] * G::Zp::from_u64(x_bits[j]));
denominator += G::G1::projective(g_list[n + 1 - j + i])
.mul_scalar(y[j] * G::Zp::from_u64(x_bits[j]));
}
}
denominator = denominator.mul_scalar(t[i]);
@@ -171,14 +176,16 @@ pub fn prove<G: Curve>(
let mut proof_y = g.mul_scalar(gamma_y);
for j in 1..n + 1 {
proof_y -= g_list[n + 1 - j].mul_scalar(y[j] * G::Zp::from_u64(1 - x_bits[j]));
proof_y -=
G::G1::projective(g_list[n + 1 - j]).mul_scalar(y[j] * G::Zp::from_u64(1 - x_bits[j]));
}
proof_y = proof_y.mul_scalar(gamma);
for i in 1..n + 1 {
let mut term = g_list[i].mul_scalar(gamma_y);
let mut term = G::G1::projective(g_list[i]).mul_scalar(gamma_y);
for j in 1..n + 1 {
if j != i {
term -= g_list[n + 1 - j + i].mul_scalar(y[j] * G::Zp::from_u64(1 - x_bits[j]));
term -= G::G1::projective(g_list[n + 1 - j + i])
.mul_scalar(y[j] * G::Zp::from_u64(1 - x_bits[j]));
}
}
let term = if x_bits[i] != 0 { term } else { G::G1::ZERO };
@@ -202,7 +209,8 @@ pub fn prove<G: Curve>(
let mut proof_v = G::G1::ZERO;
for i in 2..n + 1 {
proof_v += G::G1::mul_scalar(
g_list[n + 1 - i].mul_scalar(r) + g_list[n + 2 - i].mul_scalar(G::Zp::from_u64(x)),
G::G1::projective(g_list[n + 1 - i]).mul_scalar(r)
+ G::G1::projective(g_list[n + 2 - i]).mul_scalar(G::Zp::from_u64(x)),
s[i],
);
}
@@ -300,7 +308,7 @@ pub fn verify<G: Curve>(
let numerator = {
let mut p = c_y.mul_scalar(delta_y);
for i in 1..n + 1 {
let g = g_list[n + 1 - i];
let g = G::G1::projective(g_list[n + 1 - i]);
if i <= l {
p += g.mul_scalar(delta_x * G::Zp::from_u64(1 << (i - 1)));
}
@@ -309,16 +317,16 @@ pub fn verify<G: Curve>(
e(p, c_hat)
};
let denominator_0 = {
let mut p = g_list[n].mul_scalar(delta_x);
let mut p = G::G1::projective(g_list[n]).mul_scalar(delta_x);
for i in 2..n + 1 {
p -= g_list[n + 1 - i].mul_scalar(delta_v * s[i]);
p -= G::G1::projective(g_list[n + 1 - i]).mul_scalar(delta_v * s[i]);
}
e(p, v_hat)
};
let denominator_1 = {
let mut q = G::G2::ZERO;
for i in 1..n + 1 {
q += g_hat_list[i].mul_scalar(delta_eq * t[i]);
q += G::G2::projective(g_hat_list[i]).mul_scalar(delta_eq * t[i]);
}
e(c_y, q)
};

View File

@@ -19,8 +19,8 @@ pub struct PublicParams<G: Curve> {
impl<G: Curve> PublicParams<G> {
pub fn from_vec(
g_list: Vec<G::G1>,
g_hat_list: Vec<G::G2>,
g_list: Vec<Affine<G::Zp, G::G1>>,
g_hat_list: Vec<Affine<G::Zp, G::G2>>,
d: usize,
big_n: usize,
big_m: usize,
@@ -268,7 +268,11 @@ pub fn prove<G: Curve>(
let mut c_hat = g_hat.mul_scalar(gamma);
for j in 1..big_d + 1 {
let term = if w[j] { g_hat_list[j] } else { G::G2::ZERO };
let term = if w[j] {
G::G2::projective(g_hat_list[j])
} else {
G::G2::ZERO
};
c_hat += term;
}
@@ -763,7 +767,11 @@ pub fn verify<G: Curve>(
if e(pi, G::G2::GENERATOR)
!= e(c_y.mul_scalar(delta_y) + c_h, c_hat)
- e(c_y.mul_scalar(delta_eq), c_hat_t)
- e(g_list[1], g_hat_list[n]).mul_scalar(t_theta * delta_theta)
- e(
G::G1::projective(g_list[1]),
G::G2::projective(g_hat_list[n]),
)
.mul_scalar(t_theta * delta_theta)
{
return Err(());
}
@@ -796,7 +804,10 @@ pub fn verify<G: Curve>(
if e(c_h - G::G1::GENERATOR.mul_scalar(p_h), G::G2::GENERATOR)
+ e(G::G1::GENERATOR, c_hat_t - G::G2::GENERATOR.mul_scalar(p_t)).mul_scalar(w)
== e(pi_kzg, g_hat_list[1] - G::G2::GENERATOR.mul_scalar(z))
== e(
pi_kzg,
G::G2::projective(g_hat_list[1]) - G::G2::GENERATOR.mul_scalar(z),
)
{
Ok(())
} else {
@@ -813,7 +824,7 @@ pub fn verify<G: Curve>(
if i < big_d + 1 {
factor += delta_theta * a_theta[i - 1];
}
g_list[n + 1 - i].mul_scalar(factor)
G::G1::projective(g_list[n + 1 - i]).mul_scalar(factor)
})
.sum::<G::G1>();
let q = c_hat;
@@ -823,14 +834,14 @@ pub fn verify<G: Curve>(
let p = c_y;
let q = (1..n + 1)
.into_par_iter()
.map(|i| g_hat_list[i].mul_scalar(delta_eq * t[i]))
.map(|i| G::G2::projective(g_hat_list[i]).mul_scalar(delta_eq * t[i]))
.sum::<G::G2>();
e(p, q)
},
);
let term2 = {
let p = g_list[1];
let q = g_hat_list[n];
let p = G::G1::projective(g_list[1]);
let q = G::G2::projective(g_hat_list[n]);
e(p, q)
};