mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
feat(zk): improve performance of zk pke proofs
This commit is contained in:
@@ -12,15 +12,15 @@ description = "tfhe-zk-pok: An implementation of zero-knowledge proofs of encryp
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
ark-bls12-381 = "0.4.0"
|
||||
ark-ec = "0.4.2"
|
||||
ark-ff = "0.4.2"
|
||||
ark-poly = "0.4.2"
|
||||
ark-bls12-381 = { package = "tfhe-ark-bls12-381", version = "0.4.0" }
|
||||
ark-ec = { package = "tfhe-ark-ec", version = "0.4.2" }
|
||||
ark-ff = { package = "tfhe-ark-ff", version = "0.4.3" }
|
||||
ark-poly = { package = "tfhe-ark-poly", version = "0.4.2" }
|
||||
ark-serialize = { version = "0.4.2" }
|
||||
rand = "0.8.5"
|
||||
rayon = "1.8.0"
|
||||
sha3 = "0.10.8"
|
||||
serde = { version = "~1.0", features = ["derive"] }
|
||||
ark-serialize = { version = "0.4.2" }
|
||||
zeroize = "1.7.0"
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -466,8 +466,8 @@ pub mod g1 {
|
||||
use ark_ec::bls12::Bls12Config;
|
||||
use ark_ec::models::CurveConfig;
|
||||
use ark_ec::short_weierstrass::{Affine, SWCurveConfig};
|
||||
use ark_ec::{bls12, AffineRepr, Group};
|
||||
use ark_ff::{Field, MontFp, One, PrimeField, Zero};
|
||||
use ark_ec::{bls12, AdditiveGroup, AffineRepr, PrimeGroup};
|
||||
use ark_ff::{MontFp, One, PrimeField, Zero};
|
||||
use ark_serialize::{Compress, SerializationError};
|
||||
use core::ops::Neg;
|
||||
|
||||
@@ -631,7 +631,7 @@ pub mod g2 {
|
||||
use ark_ec::models::CurveConfig;
|
||||
use ark_ec::short_weierstrass::{Affine, SWCurveConfig};
|
||||
use ark_ec::{bls12, AffineRepr};
|
||||
use ark_ff::{MontFp, Zero};
|
||||
use ark_ff::MontFp;
|
||||
use ark_serialize::{Compress, SerializationError};
|
||||
|
||||
pub type G2Affine = bls12::G2Affine<super::Config>;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use ark_ec::{CurveGroup, Group, VariableBaseMSM};
|
||||
use ark_ec::{AdditiveGroup as Group, CurveGroup, VariableBaseMSM};
|
||||
use ark_ff::{BigInt, Field, MontFp, Zero};
|
||||
use ark_poly::univariate::DensePolynomial;
|
||||
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
|
||||
@@ -37,6 +37,8 @@ impl<T: fmt::Display + PartialEq + Field> fmt::Debug for MontIntDisplay<'_, T> {
|
||||
}
|
||||
}
|
||||
|
||||
pub mod msm;
|
||||
|
||||
pub mod bls12_381;
|
||||
pub mod bls12_446;
|
||||
|
||||
@@ -44,6 +46,7 @@ pub trait FieldOps:
|
||||
Copy
|
||||
+ Send
|
||||
+ Sync
|
||||
+ core::fmt::Debug
|
||||
+ core::ops::AddAssign<Self>
|
||||
+ core::ops::SubAssign<Self>
|
||||
+ core::ops::Add<Self, Output = Self>
|
||||
@@ -62,6 +65,7 @@ pub trait FieldOps:
|
||||
fn to_bytes(self) -> impl AsRef<[u8]>;
|
||||
fn rand(rng: &mut dyn rand::RngCore) -> Self;
|
||||
fn hash(values: &mut [Self], data: &[&[u8]]);
|
||||
fn hash_128bit(values: &mut [Self], data: &[&[u8]]);
|
||||
fn poly_mul(p: &[Self], q: &[Self]) -> Vec<Self>;
|
||||
fn poly_sub(p: &[Self], q: &[Self]) -> Vec<Self> {
|
||||
use core::iter::zip;
|
||||
@@ -113,10 +117,22 @@ pub trait CurveGroupOps<Zp>:
|
||||
const GENERATOR: Self;
|
||||
const BYTE_SIZE: usize;
|
||||
|
||||
type Affine: Copy
|
||||
+ Send
|
||||
+ Sync
|
||||
+ core::fmt::Debug
|
||||
+ serde::Serialize
|
||||
+ for<'de> serde::Deserialize<'de>
|
||||
+ CanonicalSerialize
|
||||
+ CanonicalDeserialize;
|
||||
|
||||
fn projective(affine: Self::Affine) -> Self;
|
||||
|
||||
fn mul_scalar(self, scalar: Zp) -> Self;
|
||||
fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self;
|
||||
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[Zp]) -> Self;
|
||||
fn to_bytes(self) -> impl AsRef<[u8]>;
|
||||
fn double(self) -> Self;
|
||||
fn normalize(self) -> Self::Affine;
|
||||
}
|
||||
|
||||
pub trait PairingGroupOps<Zp, G1, G2>:
|
||||
@@ -164,6 +180,9 @@ impl FieldOps for bls12_381::Zp {
|
||||
fn hash(values: &mut [Self], data: &[&[u8]]) {
|
||||
Self::hash(values, data)
|
||||
}
|
||||
fn hash_128bit(values: &mut [Self], data: &[&[u8]]) {
|
||||
Self::hash_128bit(values, data)
|
||||
}
|
||||
|
||||
fn poly_mul(p: &[Self], q: &[Self]) -> Vec<Self> {
|
||||
let p = p.iter().map(|x| x.inner).collect();
|
||||
@@ -182,13 +201,20 @@ impl CurveGroupOps<bls12_381::Zp> for bls12_381::G1 {
|
||||
const ZERO: Self = Self::ZERO;
|
||||
const GENERATOR: Self = Self::GENERATOR;
|
||||
const BYTE_SIZE: usize = Self::BYTE_SIZE;
|
||||
type Affine = bls12_381::G1Affine;
|
||||
|
||||
fn projective(affine: Self::Affine) -> Self {
|
||||
Self {
|
||||
inner: affine.inner.into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn mul_scalar(self, scalar: bls12_381::Zp) -> Self {
|
||||
self.mul_scalar(scalar)
|
||||
}
|
||||
|
||||
fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_381::Zp]) -> Self {
|
||||
Self::multi_mul_scalar(bases, scalars)
|
||||
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_381::Zp]) -> Self {
|
||||
Self::Affine::multi_mul_scalar(bases, scalars)
|
||||
}
|
||||
|
||||
fn to_bytes(self) -> impl AsRef<[u8]> {
|
||||
@@ -198,19 +224,32 @@ impl CurveGroupOps<bls12_381::Zp> for bls12_381::G1 {
|
||||
fn double(self) -> Self {
|
||||
self.double()
|
||||
}
|
||||
|
||||
fn normalize(self) -> Self::Affine {
|
||||
Self::Affine {
|
||||
inner: self.inner.into_affine(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CurveGroupOps<bls12_381::Zp> for bls12_381::G2 {
|
||||
const ZERO: Self = Self::ZERO;
|
||||
const GENERATOR: Self = Self::GENERATOR;
|
||||
const BYTE_SIZE: usize = Self::BYTE_SIZE;
|
||||
type Affine = bls12_381::G2Affine;
|
||||
|
||||
fn projective(affine: Self::Affine) -> Self {
|
||||
Self {
|
||||
inner: affine.inner.into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn mul_scalar(self, scalar: bls12_381::Zp) -> Self {
|
||||
self.mul_scalar(scalar)
|
||||
}
|
||||
|
||||
fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_381::Zp]) -> Self {
|
||||
Self::multi_mul_scalar(bases, scalars)
|
||||
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_381::Zp]) -> Self {
|
||||
Self::Affine::multi_mul_scalar(bases, scalars)
|
||||
}
|
||||
|
||||
fn to_bytes(self) -> impl AsRef<[u8]> {
|
||||
@@ -220,6 +259,12 @@ impl CurveGroupOps<bls12_381::Zp> for bls12_381::G2 {
|
||||
fn double(self) -> Self {
|
||||
self.double()
|
||||
}
|
||||
|
||||
fn normalize(self) -> Self::Affine {
|
||||
Self::Affine {
|
||||
inner: self.inner.into_affine(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PairingGroupOps<bls12_381::Zp, bls12_381::G1, bls12_381::G2> for bls12_381::Gt {
|
||||
@@ -254,6 +299,9 @@ impl FieldOps for bls12_446::Zp {
|
||||
fn hash(values: &mut [Self], data: &[&[u8]]) {
|
||||
Self::hash(values, data)
|
||||
}
|
||||
fn hash_128bit(values: &mut [Self], data: &[&[u8]]) {
|
||||
Self::hash_128bit(values, data)
|
||||
}
|
||||
|
||||
fn poly_mul(p: &[Self], q: &[Self]) -> Vec<Self> {
|
||||
let p = p.iter().map(|x| x.inner).collect();
|
||||
@@ -272,13 +320,21 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G1 {
|
||||
const ZERO: Self = Self::ZERO;
|
||||
const GENERATOR: Self = Self::GENERATOR;
|
||||
const BYTE_SIZE: usize = Self::BYTE_SIZE;
|
||||
type Affine = bls12_446::G1Affine;
|
||||
|
||||
fn projective(affine: Self::Affine) -> Self {
|
||||
Self {
|
||||
inner: affine.inner.into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn mul_scalar(self, scalar: bls12_446::Zp) -> Self {
|
||||
self.mul_scalar(scalar)
|
||||
}
|
||||
|
||||
fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_446::Zp]) -> Self {
|
||||
Self::multi_mul_scalar(bases, scalars)
|
||||
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_446::Zp]) -> Self {
|
||||
msm::msm_wnaf_g1_446(bases, scalars)
|
||||
// Self::Affine::multi_mul_scalar(bases, scalars)
|
||||
}
|
||||
|
||||
fn to_bytes(self) -> impl AsRef<[u8]> {
|
||||
@@ -288,19 +344,32 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G1 {
|
||||
fn double(self) -> Self {
|
||||
self.double()
|
||||
}
|
||||
|
||||
fn normalize(self) -> Self::Affine {
|
||||
Self::Affine {
|
||||
inner: self.inner.into_affine(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CurveGroupOps<bls12_446::Zp> for bls12_446::G2 {
|
||||
const ZERO: Self = Self::ZERO;
|
||||
const GENERATOR: Self = Self::GENERATOR;
|
||||
const BYTE_SIZE: usize = Self::BYTE_SIZE;
|
||||
type Affine = bls12_446::G2Affine;
|
||||
|
||||
fn projective(affine: Self::Affine) -> Self {
|
||||
Self {
|
||||
inner: affine.inner.into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn mul_scalar(self, scalar: bls12_446::Zp) -> Self {
|
||||
self.mul_scalar(scalar)
|
||||
}
|
||||
|
||||
fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_446::Zp]) -> Self {
|
||||
Self::multi_mul_scalar(bases, scalars)
|
||||
fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_446::Zp]) -> Self {
|
||||
Self::Affine::multi_mul_scalar(bases, scalars)
|
||||
}
|
||||
|
||||
fn to_bytes(self) -> impl AsRef<[u8]> {
|
||||
@@ -310,6 +379,12 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G2 {
|
||||
fn double(self) -> Self {
|
||||
self.double()
|
||||
}
|
||||
|
||||
fn normalize(self) -> Self::Affine {
|
||||
Self::Affine {
|
||||
inner: self.inner.into_affine(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PairingGroupOps<bls12_446::Zp, bls12_446::G1, bls12_446::G2> for bls12_446::Gt {
|
||||
|
||||
@@ -36,6 +36,39 @@ fn bigint_to_bytes(x: [u64; 6]) -> [u8; 6 * 8] {
|
||||
mod g1 {
|
||||
use super::*;
|
||||
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
Debug,
|
||||
PartialEq,
|
||||
Eq,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
Hash,
|
||||
CanonicalSerialize,
|
||||
CanonicalDeserialize,
|
||||
)]
|
||||
#[repr(transparent)]
|
||||
pub struct G1Affine {
|
||||
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
|
||||
pub(crate) inner: ark_bls12_381::g1::G1Affine,
|
||||
}
|
||||
|
||||
impl G1Affine {
|
||||
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G1 {
|
||||
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
|
||||
G1 {
|
||||
inner: ark_bls12_381::g1::G1Projective::msm(
|
||||
unsafe {
|
||||
&*(bases as *const [G1Affine] as *const [ark_bls12_381::g1::G1Affine])
|
||||
},
|
||||
unsafe { &*(scalars as *const [Zp] as *const [ark_bls12_381::Fr]) },
|
||||
)
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
@@ -179,6 +212,39 @@ mod g1 {
|
||||
mod g2 {
|
||||
use super::*;
|
||||
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
Debug,
|
||||
PartialEq,
|
||||
Eq,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
Hash,
|
||||
CanonicalSerialize,
|
||||
CanonicalDeserialize,
|
||||
)]
|
||||
#[repr(transparent)]
|
||||
pub struct G2Affine {
|
||||
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
|
||||
pub(crate) inner: ark_bls12_381::g2::G2Affine,
|
||||
}
|
||||
|
||||
impl G2Affine {
|
||||
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G2 {
|
||||
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
|
||||
G2 {
|
||||
inner: ark_bls12_381::g2::G2Projective::msm(
|
||||
unsafe {
|
||||
&*(bases as *const [G2Affine] as *const [ark_bls12_381::g2::G2Affine])
|
||||
},
|
||||
unsafe { &*(scalars as *const [Zp] as *const [ark_bls12_381::Fr]) },
|
||||
)
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
@@ -193,7 +259,7 @@ mod g2 {
|
||||
#[repr(transparent)]
|
||||
pub struct G2 {
|
||||
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
|
||||
pub(super) inner: ark_bls12_381::G2Projective,
|
||||
pub(crate) inner: ark_bls12_381::G2Projective,
|
||||
}
|
||||
|
||||
impl fmt::Debug for G2 {
|
||||
@@ -640,6 +706,25 @@ mod zp {
|
||||
*value = Zp::from_raw_u64x6(bytes.map(u64::from_le_bytes));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn hash_128bit(values: &mut [Zp], data: &[&[u8]]) {
|
||||
use sha3::digest::{ExtendableOutput, Update, XofReader};
|
||||
|
||||
let mut hasher = sha3::Shake256::default();
|
||||
for data in data {
|
||||
hasher.update(data);
|
||||
}
|
||||
let mut reader = hasher.finalize_xof();
|
||||
|
||||
for value in values {
|
||||
let mut bytes = [0u8; 2 * 8];
|
||||
reader.read(&mut bytes);
|
||||
let limbs: [u64; 2] = unsafe { core::mem::transmute(bytes) };
|
||||
*value = Zp {
|
||||
inner: BigInt([limbs[0], limbs[1], 0, 0]).into(),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Add for Zp {
|
||||
@@ -714,8 +799,8 @@ mod zp {
|
||||
}
|
||||
}
|
||||
|
||||
pub use g1::G1;
|
||||
pub use g2::G2;
|
||||
pub use g1::{G1Affine, G1};
|
||||
pub use g2::{G2Affine, G2};
|
||||
pub use gt::Gt;
|
||||
pub use zp::Zp;
|
||||
|
||||
|
||||
@@ -36,6 +36,39 @@ fn bigint_to_bytes(x: [u64; 7]) -> [u8; 7 * 8] {
|
||||
mod g1 {
|
||||
use super::*;
|
||||
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
Debug,
|
||||
PartialEq,
|
||||
Eq,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
Hash,
|
||||
CanonicalSerialize,
|
||||
CanonicalDeserialize,
|
||||
)]
|
||||
#[repr(transparent)]
|
||||
pub struct G1Affine {
|
||||
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
|
||||
pub(crate) inner: crate::curve_446::g1::G1Affine,
|
||||
}
|
||||
|
||||
impl G1Affine {
|
||||
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G1 {
|
||||
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
|
||||
G1 {
|
||||
inner: crate::curve_446::g1::G1Projective::msm(
|
||||
unsafe {
|
||||
&*(bases as *const [G1Affine] as *const [crate::curve_446::g1::G1Affine])
|
||||
},
|
||||
unsafe { &*(scalars as *const [Zp] as *const [crate::curve_446::Fr]) },
|
||||
)
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
@@ -93,17 +126,17 @@ mod g1 {
|
||||
|
||||
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self {
|
||||
use rayon::prelude::*;
|
||||
let n_threads = rayon::current_num_threads();
|
||||
let chunk_size = bases.len().div_ceil(n_threads);
|
||||
bases
|
||||
let bases = bases
|
||||
.par_iter()
|
||||
.map(|&x| x.inner.into_affine())
|
||||
.chunks(chunk_size)
|
||||
.zip(scalars.par_iter().map(|&x| x.inner).chunks(chunk_size))
|
||||
.map(|(bases, scalars)| Self {
|
||||
inner: crate::curve_446::g1::G1Projective::msm(&bases, &scalars).unwrap(),
|
||||
.collect::<Vec<_>>();
|
||||
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
|
||||
Self {
|
||||
inner: crate::curve_446::g1::G1Projective::msm(&bases, unsafe {
|
||||
&*(scalars as *const [Zp] as *const [crate::curve_446::Fr])
|
||||
})
|
||||
.sum::<Self>()
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_bytes(self) -> [u8; Self::BYTE_SIZE] {
|
||||
@@ -178,6 +211,39 @@ mod g1 {
|
||||
mod g2 {
|
||||
use super::*;
|
||||
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
Debug,
|
||||
PartialEq,
|
||||
Eq,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
Hash,
|
||||
CanonicalSerialize,
|
||||
CanonicalDeserialize,
|
||||
)]
|
||||
#[repr(transparent)]
|
||||
pub struct G2Affine {
|
||||
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
|
||||
pub(crate) inner: crate::curve_446::g2::G2Affine,
|
||||
}
|
||||
|
||||
impl G2Affine {
|
||||
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G2 {
|
||||
// SAFETY: interpreting a `repr(transparent)` pointer as its contents.
|
||||
G2 {
|
||||
inner: crate::curve_446::g2::G2Projective::msm(
|
||||
unsafe {
|
||||
&*(bases as *const [G2Affine] as *const [crate::curve_446::g2::G2Affine])
|
||||
},
|
||||
unsafe { &*(scalars as *const [Zp] as *const [crate::curve_446::Fr]) },
|
||||
)
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
@@ -192,7 +258,7 @@ mod g2 {
|
||||
#[repr(transparent)]
|
||||
pub struct G2 {
|
||||
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
|
||||
pub(super) inner: crate::curve_446::g2::G2Projective,
|
||||
pub(crate) inner: crate::curve_446::g2::G2Projective,
|
||||
}
|
||||
|
||||
impl fmt::Debug for G2 {
|
||||
@@ -672,7 +738,7 @@ mod zp {
|
||||
#[repr(transparent)]
|
||||
pub struct Zp {
|
||||
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
|
||||
pub(crate) inner: crate::curve_446::Fr,
|
||||
pub inner: crate::curve_446::Fr,
|
||||
}
|
||||
|
||||
impl fmt::Debug for Zp {
|
||||
@@ -772,6 +838,25 @@ mod zp {
|
||||
*value = Zp::from_raw_u64x7(bytes.map(u64::from_le_bytes));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn hash_128bit(values: &mut [Zp], data: &[&[u8]]) {
|
||||
use sha3::digest::{ExtendableOutput, Update, XofReader};
|
||||
|
||||
let mut hasher = sha3::Shake256::default();
|
||||
for data in data {
|
||||
hasher.update(data);
|
||||
}
|
||||
let mut reader = hasher.finalize_xof();
|
||||
|
||||
for value in values {
|
||||
let mut bytes = [0u8; 2 * 8];
|
||||
reader.read(&mut bytes);
|
||||
let limbs: [u64; 2] = unsafe { core::mem::transmute(bytes) };
|
||||
*value = Zp {
|
||||
inner: BigInt([limbs[0], limbs[1], 0, 0, 0]).into(),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Add for Zp {
|
||||
@@ -846,8 +931,8 @@ mod zp {
|
||||
}
|
||||
}
|
||||
|
||||
pub use g1::G1;
|
||||
pub use g2::G2;
|
||||
pub use g1::{G1Affine, G1};
|
||||
pub use g2::{G2Affine, G2};
|
||||
pub use gt::Gt;
|
||||
pub use zp::Zp;
|
||||
|
||||
|
||||
442
tfhe-zk-pok/src/curve_api/msm.rs
Normal file
442
tfhe-zk-pok/src/curve_api/msm.rs
Normal file
@@ -0,0 +1,442 @@
|
||||
use ark_ec::short_weierstrass::Affine;
|
||||
use ark_ec::AffineRepr;
|
||||
use ark_ff::{AdditiveGroup, BigInt, BigInteger, Field, Fp, PrimeField};
|
||||
use rayon::prelude::*;
|
||||
|
||||
fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator<Item = i64> + '_ {
|
||||
let scalar = a.as_ref();
|
||||
let radix: u64 = 1 << w;
|
||||
let window_mask: u64 = radix - 1;
|
||||
|
||||
let mut carry = 0u64;
|
||||
let num_bits = if num_bits == 0 {
|
||||
a.num_bits() as usize
|
||||
} else {
|
||||
num_bits
|
||||
};
|
||||
let digits_count = (num_bits + w - 1) / w;
|
||||
|
||||
(0..digits_count).map(move |i| {
|
||||
// Construct a buffer of bits of the scalar, starting at `bit_offset`.
|
||||
let bit_offset = i * w;
|
||||
let u64_idx = bit_offset / 64;
|
||||
let bit_idx = bit_offset % 64;
|
||||
// Read the bits from the scalar
|
||||
let bit_buf = if bit_idx < 64 - w || u64_idx == scalar.len() - 1 {
|
||||
// This window's bits are contained in a single u64,
|
||||
// or it's the last u64 anyway.
|
||||
scalar[u64_idx] >> bit_idx
|
||||
} else {
|
||||
// Combine the current u64's bits with the bits from the next u64
|
||||
(scalar[u64_idx] >> bit_idx) | (scalar[1 + u64_idx] << (64 - bit_idx))
|
||||
};
|
||||
|
||||
// Read the actual coefficient value from the window
|
||||
let coef = carry + (bit_buf & window_mask); // coef = [0, 2^r)
|
||||
|
||||
// Recenter coefficients from [0,2^w) to [-2^w/2, 2^w/2)
|
||||
carry = (coef + radix / 2) >> w;
|
||||
let mut digit = (coef as i64) - (carry << w) as i64;
|
||||
|
||||
if i == digits_count - 1 {
|
||||
digit += (carry << w) as i64;
|
||||
}
|
||||
digit
|
||||
})
|
||||
}
|
||||
|
||||
// Compute msm using windowed non-adjacent form
|
||||
pub fn msm_wnaf_g1_446(
|
||||
bases: &[super::bls12_446::G1Affine],
|
||||
scalars: &[super::bls12_446::Zp],
|
||||
) -> super::bls12_446::G1 {
|
||||
use super::bls12_446::*;
|
||||
let num_bits = 299usize;
|
||||
type BaseField = Fp<ark_ff::MontBackend<crate::curve_446::FqConfig, 7>, 7>;
|
||||
|
||||
assert_eq!(bases.len(), scalars.len());
|
||||
|
||||
let size = bases.len();
|
||||
let scalars = &*scalars
|
||||
.par_iter()
|
||||
.map(|x| x.inner.into_bigint())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let c = if size < 32 {
|
||||
3
|
||||
} else {
|
||||
// natural log approx
|
||||
(size.ilog2() as usize * 69 / 100) + 2
|
||||
};
|
||||
|
||||
let digits_count = (num_bits + c - 1) / c;
|
||||
let scalar_digits = scalars
|
||||
.into_par_iter()
|
||||
.flat_map_iter(|s| make_digits(s, c, num_bits))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let zero = G1Affine {
|
||||
inner: Affine::zero(),
|
||||
};
|
||||
|
||||
let window_sums: Vec<_> = (0..digits_count)
|
||||
.into_par_iter()
|
||||
.map(|i| {
|
||||
let n = 1 << c;
|
||||
let mut indices = vec![vec![]; n];
|
||||
let mut d = vec![BaseField::ZERO; n + 1];
|
||||
let mut e = vec![BaseField::ZERO; n + 1];
|
||||
|
||||
for (idx, digits) in scalar_digits.chunks(digits_count).enumerate() {
|
||||
use core::cmp::Ordering;
|
||||
// digits is the digits thing of the first scalar?
|
||||
let scalar = digits[i];
|
||||
match 0.cmp(&scalar) {
|
||||
Ordering::Less => indices[(scalar - 1) as usize].push(idx),
|
||||
Ordering::Greater => indices[(-scalar - 1) as usize].push(!idx),
|
||||
Ordering::Equal => (),
|
||||
}
|
||||
}
|
||||
|
||||
let mut buckets = vec![zero; 1 << c];
|
||||
|
||||
loop {
|
||||
d[0] = BaseField::ONE;
|
||||
for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices).enumerate() {
|
||||
if let Some(idx) = idx.last().copied() {
|
||||
let value = if idx >> (usize::BITS - 1) == 1 {
|
||||
let mut val = bases[!idx];
|
||||
val.inner.y = -val.inner.y;
|
||||
val
|
||||
} else {
|
||||
bases[idx]
|
||||
};
|
||||
|
||||
if !bucket.inner.infinity {
|
||||
let a = value.inner.x - bucket.inner.x;
|
||||
if a != BaseField::ZERO {
|
||||
d[k + 1] = d[k] * a;
|
||||
} else if value.inner.y == bucket.inner.y {
|
||||
d[k + 1] = d[k] * value.inner.y.double();
|
||||
} else {
|
||||
d[k + 1] = d[k];
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
d[k + 1] = d[k];
|
||||
}
|
||||
e[n] = d[n].inverse().unwrap();
|
||||
|
||||
for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices)
|
||||
.enumerate()
|
||||
.rev()
|
||||
{
|
||||
if let Some(idx) = idx.last().copied() {
|
||||
let value = if idx >> (usize::BITS - 1) == 1 {
|
||||
let mut val = bases[!idx];
|
||||
val.inner.y = -val.inner.y;
|
||||
val
|
||||
} else {
|
||||
bases[idx]
|
||||
};
|
||||
|
||||
if !bucket.inner.infinity {
|
||||
let a = value.inner.x - bucket.inner.x;
|
||||
if a != BaseField::ZERO {
|
||||
e[k] = e[k + 1] * a;
|
||||
} else if value.inner.y == bucket.inner.y {
|
||||
e[k] = e[k + 1] * value.inner.y.double();
|
||||
} else {
|
||||
e[k] = e[k + 1];
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
e[k] = e[k + 1];
|
||||
}
|
||||
|
||||
let d = &d[..n];
|
||||
let e = &e[1..];
|
||||
|
||||
let mut empty = true;
|
||||
for ((&d, &e), (bucket, idx)) in core::iter::zip(
|
||||
core::iter::zip(d, e),
|
||||
core::iter::zip(&mut buckets, &mut indices),
|
||||
) {
|
||||
empty &= idx.len() <= 1;
|
||||
if let Some(idx) = idx.pop() {
|
||||
let value = if idx >> (usize::BITS - 1) == 1 {
|
||||
let mut val = bases[!idx];
|
||||
val.inner.y = -val.inner.y;
|
||||
val
|
||||
} else {
|
||||
bases[idx]
|
||||
};
|
||||
|
||||
if !bucket.inner.infinity {
|
||||
let x1 = bucket.inner.x;
|
||||
let x2 = value.inner.x;
|
||||
let y1 = bucket.inner.y;
|
||||
let y2 = value.inner.y;
|
||||
|
||||
let eq_x = x1 == x2;
|
||||
|
||||
if eq_x && y1 != y2 {
|
||||
bucket.inner.infinity = true;
|
||||
} else {
|
||||
let r = d * e;
|
||||
let m = if eq_x {
|
||||
let x1 = x1.square();
|
||||
x1 + x1.double()
|
||||
} else {
|
||||
y2 - y1
|
||||
};
|
||||
let m = m * r;
|
||||
|
||||
let x3 = m.square() - x1 - x2;
|
||||
let y3 = m * (x1 - x3) - y1;
|
||||
bucket.inner.x = x3;
|
||||
bucket.inner.y = y3;
|
||||
}
|
||||
} else {
|
||||
*bucket = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if empty {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut running_sum = G1::ZERO;
|
||||
let mut res = G1::ZERO;
|
||||
buckets.into_iter().rev().for_each(|b| {
|
||||
running_sum.inner += b.inner;
|
||||
res += running_sum;
|
||||
});
|
||||
res
|
||||
})
|
||||
.collect();
|
||||
|
||||
// We store the sum for the lowest window.
|
||||
let lowest = *window_sums.first().unwrap();
|
||||
|
||||
// We're traversing windows from high to low.
|
||||
lowest
|
||||
+ window_sums[1..]
|
||||
.iter()
|
||||
.rev()
|
||||
.fold(G1::ZERO, |mut total, &sum_i| {
|
||||
total += sum_i;
|
||||
for _ in 0..c {
|
||||
total = total.double();
|
||||
}
|
||||
total
|
||||
})
|
||||
}
|
||||
|
||||
// Compute msm using windowed non-adjacent form
|
||||
pub fn msm_wnaf_g1_446_extended(
|
||||
bases: &[super::bls12_446::G1Affine],
|
||||
scalars: &[super::bls12_446::Zp],
|
||||
) -> super::bls12_446::G1 {
|
||||
use super::bls12_446::*;
|
||||
type BaseField = Fp<ark_ff::MontBackend<crate::curve_446::FqConfig, 7>, 7>;
|
||||
|
||||
// let num_bits = 75usize;
|
||||
// let mask = BigInt([!0, (1 << 11) - 1, 0, 0, 0]);
|
||||
// let scalars = &*scalars
|
||||
// .par_iter()
|
||||
// .map(|x| x.inner.into_bigint())
|
||||
// .flat_map_iter(|x| (0..4).map(move |i| (x >> (75 * i)) & mask))
|
||||
// .collect::<Vec<_>>();
|
||||
|
||||
let num_bits = 150usize;
|
||||
let mask = BigInt([!0, !0, (1 << 22) - 1, 0, 0]);
|
||||
let scalars = &*scalars
|
||||
.par_iter()
|
||||
.map(|x| x.inner.into_bigint())
|
||||
.flat_map_iter(|x| (0..2).map(move |i| (x >> (150 * i)) & mask))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(bases.len(), scalars.len());
|
||||
|
||||
let size = bases.len();
|
||||
|
||||
let c = if size < 32 {
|
||||
3
|
||||
} else {
|
||||
// natural log approx
|
||||
(size.ilog2() as usize * 69 / 100) + 2
|
||||
};
|
||||
let c = c - 3;
|
||||
|
||||
let digits_count = (num_bits + c - 1) / c;
|
||||
let scalar_digits = scalars
|
||||
.into_par_iter()
|
||||
.flat_map_iter(|s| make_digits(s, c, num_bits))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let zero = G1Affine {
|
||||
inner: Affine::zero(),
|
||||
};
|
||||
|
||||
let window_sums: Vec<_> = (0..digits_count)
|
||||
.into_par_iter()
|
||||
.map(|i| {
|
||||
let n = 1 << c;
|
||||
let mut indices = vec![vec![]; n];
|
||||
let mut d = vec![BaseField::ZERO; n + 1];
|
||||
let mut e = vec![BaseField::ZERO; n + 1];
|
||||
|
||||
for (idx, digits) in scalar_digits.chunks(digits_count).enumerate() {
|
||||
use core::cmp::Ordering;
|
||||
// digits is the digits thing of the first scalar?
|
||||
let scalar = digits[i];
|
||||
match 0.cmp(&scalar) {
|
||||
Ordering::Less => indices[(scalar - 1) as usize].push(idx),
|
||||
Ordering::Greater => indices[(-scalar - 1) as usize].push(!idx),
|
||||
Ordering::Equal => (),
|
||||
}
|
||||
}
|
||||
|
||||
let mut buckets = vec![zero; 1 << c];
|
||||
|
||||
loop {
|
||||
d[0] = BaseField::ONE;
|
||||
for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices).enumerate() {
|
||||
if let Some(idx) = idx.last().copied() {
|
||||
let value = if idx >> (usize::BITS - 1) == 1 {
|
||||
let mut val = bases[!idx];
|
||||
val.inner.y = -val.inner.y;
|
||||
val
|
||||
} else {
|
||||
bases[idx]
|
||||
};
|
||||
|
||||
if !bucket.inner.infinity {
|
||||
let a = value.inner.x - bucket.inner.x;
|
||||
if a != BaseField::ZERO {
|
||||
d[k + 1] = d[k] * a;
|
||||
} else if value.inner.y == bucket.inner.y {
|
||||
d[k + 1] = d[k] * value.inner.y.double();
|
||||
} else {
|
||||
d[k + 1] = d[k];
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
d[k + 1] = d[k];
|
||||
}
|
||||
e[n] = d[n].inverse().unwrap();
|
||||
|
||||
for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices)
|
||||
.enumerate()
|
||||
.rev()
|
||||
{
|
||||
if let Some(idx) = idx.last().copied() {
|
||||
let value = if idx >> (usize::BITS - 1) == 1 {
|
||||
let mut val = bases[!idx];
|
||||
val.inner.y = -val.inner.y;
|
||||
val
|
||||
} else {
|
||||
bases[idx]
|
||||
};
|
||||
|
||||
if !bucket.inner.infinity {
|
||||
let a = value.inner.x - bucket.inner.x;
|
||||
if a != BaseField::ZERO {
|
||||
e[k] = e[k + 1] * a;
|
||||
} else if value.inner.y == bucket.inner.y {
|
||||
e[k] = e[k + 1] * value.inner.y.double();
|
||||
} else {
|
||||
e[k] = e[k + 1];
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
e[k] = e[k + 1];
|
||||
}
|
||||
|
||||
let d = &d[..n];
|
||||
let e = &e[1..];
|
||||
|
||||
let mut empty = true;
|
||||
for ((&d, &e), (bucket, idx)) in core::iter::zip(
|
||||
core::iter::zip(d, e),
|
||||
core::iter::zip(&mut buckets, &mut indices),
|
||||
) {
|
||||
empty &= idx.len() <= 1;
|
||||
if let Some(idx) = idx.pop() {
|
||||
let value = if idx >> (usize::BITS - 1) == 1 {
|
||||
let mut val = bases[!idx];
|
||||
val.inner.y = -val.inner.y;
|
||||
val
|
||||
} else {
|
||||
bases[idx]
|
||||
};
|
||||
|
||||
if !bucket.inner.infinity {
|
||||
let x1 = bucket.inner.x;
|
||||
let x2 = value.inner.x;
|
||||
let y1 = bucket.inner.y;
|
||||
let y2 = value.inner.y;
|
||||
|
||||
let eq_x = x1 == x2;
|
||||
|
||||
if eq_x && y1 != y2 {
|
||||
bucket.inner.infinity = true;
|
||||
} else {
|
||||
let r = d * e;
|
||||
let m = if eq_x {
|
||||
let x1 = x1.square();
|
||||
x1 + x1.double()
|
||||
} else {
|
||||
y2 - y1
|
||||
};
|
||||
let m = m * r;
|
||||
|
||||
let x3 = m.square() - x1 - x2;
|
||||
let y3 = m * (x1 - x3) - y1;
|
||||
bucket.inner.x = x3;
|
||||
bucket.inner.y = y3;
|
||||
}
|
||||
} else {
|
||||
*bucket = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if empty {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut running_sum = G1::ZERO;
|
||||
let mut res = G1::ZERO;
|
||||
buckets.into_iter().rev().for_each(|b| {
|
||||
running_sum.inner += b.inner;
|
||||
res += running_sum;
|
||||
});
|
||||
res
|
||||
})
|
||||
.collect();
|
||||
|
||||
// We store the sum for the lowest window.
|
||||
let lowest = *window_sums.first().unwrap();
|
||||
|
||||
// We're traversing windows from high to low.
|
||||
lowest
|
||||
+ window_sums[1..]
|
||||
.iter()
|
||||
.rev()
|
||||
.fold(G1::ZERO, |mut total, &sum_i| {
|
||||
total += sum_i;
|
||||
for _ in 0..c {
|
||||
total = total.double();
|
||||
}
|
||||
total
|
||||
})
|
||||
}
|
||||
@@ -6,7 +6,10 @@ pub struct PublicParams<G: Curve> {
|
||||
}
|
||||
|
||||
impl<G: Curve> PublicParams<G> {
|
||||
pub fn from_vec(g_list: Vec<G::G1>, g_hat_list: Vec<G::G2>) -> Self {
|
||||
pub fn from_vec(
|
||||
g_list: Vec<Affine<G::Zp, G::G1>>,
|
||||
g_hat_list: Vec<Affine<G::Zp, G::G2>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
g_lists: GroupElements::from_vec(g_list, g_hat_list),
|
||||
}
|
||||
@@ -57,7 +60,7 @@ pub fn commit<G: Curve>(
|
||||
let mut c_hat = g_hat.mul_scalar(gamma);
|
||||
for j in 1..n + 1 {
|
||||
let term = if x[j] != 0 {
|
||||
public.g_lists.g_hat_list[j]
|
||||
G::G2::projective(public.g_lists.g_hat_list[j])
|
||||
} else {
|
||||
G::G2::ZERO
|
||||
};
|
||||
@@ -91,7 +94,7 @@ pub fn prove<G: Curve>(
|
||||
|
||||
let mut c_y = g.mul_scalar(gamma_y);
|
||||
for j in 1..n + 1 {
|
||||
c_y += (g_list[n + 1 - j]).mul_scalar(y[j] * G::Zp::from_u64(x[j]));
|
||||
c_y += (G::G1::projective(g_list[n + 1 - j])).mul_scalar(y[j] * G::Zp::from_u64(x[j]));
|
||||
}
|
||||
|
||||
let y_bytes = &*(1..n + 1)
|
||||
@@ -143,7 +146,7 @@ pub fn prove<G: Curve>(
|
||||
|
||||
let mut proof = g.mul_scalar(poly[0]);
|
||||
for i in 1..poly.len() {
|
||||
proof += g_list[i].mul_scalar(poly[i]);
|
||||
proof += G::G1::projective(g_list[i]).mul_scalar(poly[i]);
|
||||
}
|
||||
proof
|
||||
};
|
||||
@@ -190,7 +193,7 @@ pub fn verify<G: Curve>(
|
||||
let numerator = {
|
||||
let mut p = c_y.mul_scalar(delta_y);
|
||||
for i in 1..n + 1 {
|
||||
let gy = g_list[n + 1 - i].mul_scalar(y[i]);
|
||||
let gy = G::G1::projective(g_list[n + 1 - i]).mul_scalar(y[i]);
|
||||
p += gy.mul_scalar(delta_eq).mul_scalar(t[i]) - gy.mul_scalar(delta_y);
|
||||
}
|
||||
e(p, c_hat)
|
||||
@@ -198,7 +201,9 @@ pub fn verify<G: Curve>(
|
||||
let denominator = {
|
||||
let mut q = G::G2::ZERO;
|
||||
for i in 1..n + 1 {
|
||||
q += g_hat_list[i].mul_scalar(delta_eq).mul_scalar(t[i]);
|
||||
q += G::G2::projective(g_hat_list[i])
|
||||
.mul_scalar(delta_eq)
|
||||
.mul_scalar(t[i]);
|
||||
}
|
||||
e(c_y, q)
|
||||
};
|
||||
|
||||
@@ -6,7 +6,10 @@ pub struct PublicParams<G: Curve> {
|
||||
}
|
||||
|
||||
impl<G: Curve> PublicParams<G> {
|
||||
pub fn from_vec(g_list: Vec<G::G1>, g_hat_list: Vec<G::G2>) -> Self {
|
||||
pub fn from_vec(
|
||||
g_list: Vec<Affine<G::Zp, G::G1>>,
|
||||
g_hat_list: Vec<Affine<G::Zp, G::G2>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
g_lists: GroupElements::from_vec(g_list, g_hat_list),
|
||||
}
|
||||
@@ -55,7 +58,7 @@ pub fn commit<G: Curve>(
|
||||
|
||||
let mut c = g.mul_scalar(gamma);
|
||||
for j in 1..n + 1 {
|
||||
let term = public.g_lists.g_list[j].mul_scalar(G::Zp::from_u64(m[j]));
|
||||
let term = G::G1::projective(public.g_lists.g_list[j]).mul_scalar(G::Zp::from_u64(m[j]));
|
||||
c += term;
|
||||
}
|
||||
|
||||
@@ -80,11 +83,11 @@ pub fn prove<G: Curve>(
|
||||
let gamma = private.gamma;
|
||||
let g_list = &public.0.g_lists.g_list;
|
||||
|
||||
let mut pi = g_list[n + 1 - i].mul_scalar(gamma);
|
||||
let mut pi = G::G1::projective(g_list[n + 1 - i]).mul_scalar(gamma);
|
||||
for j in 1..n + 1 {
|
||||
if i != j {
|
||||
let term = if m[j] & 1 == 1 {
|
||||
g_list[n + 1 - i + j]
|
||||
G::G1::projective(g_list[n + 1 - i + j])
|
||||
} else {
|
||||
G::G1::ZERO
|
||||
};
|
||||
@@ -110,8 +113,13 @@ pub fn verify<G: Curve>(
|
||||
let n = public.0.g_lists.message_len;
|
||||
let i = index + 1;
|
||||
|
||||
let lhs = e(c, g_hat_list[n + 1 - i]);
|
||||
let rhs = e(proof.pi, g_hat) + (e(g_list[1], g_hat_list[n])).mul_scalar(G::Zp::from_u64(mi));
|
||||
let lhs = e(c, G::G2::projective(g_hat_list[n + 1 - i]));
|
||||
let rhs = e(proof.pi, g_hat)
|
||||
+ (e(
|
||||
G::G1::projective(g_list[1]),
|
||||
G::G2::projective(g_hat_list[n]),
|
||||
))
|
||||
.mul_scalar(G::Zp::from_u64(mi));
|
||||
|
||||
if lhs == rhs {
|
||||
Ok(())
|
||||
|
||||
@@ -75,6 +75,8 @@ impl<T: ?Sized + IndexMut<usize>> IndexMut<usize> for OneBased<T> {
|
||||
}
|
||||
}
|
||||
|
||||
pub type Affine<Zp, Group> = <Group as CurveGroupOps<Zp>>::Affine;
|
||||
|
||||
#[derive(
|
||||
Clone, Debug, serde::Serialize, serde::Deserialize, CanonicalSerialize, CanonicalDeserialize,
|
||||
)]
|
||||
@@ -83,8 +85,8 @@ impl<T: ?Sized + IndexMut<usize>> IndexMut<usize> for OneBased<T> {
|
||||
serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize"
|
||||
))]
|
||||
struct GroupElements<G: Curve> {
|
||||
g_list: OneBased<Vec<G::G1>>,
|
||||
g_hat_list: OneBased<Vec<G::G2>>,
|
||||
g_list: OneBased<Vec<Affine<G::Zp, G::G1>>>,
|
||||
g_hat_list: OneBased<Vec<Affine<G::Zp, G::G2>>>,
|
||||
message_len: usize,
|
||||
}
|
||||
|
||||
@@ -98,9 +100,9 @@ impl<G: Curve> GroupElements<G> {
|
||||
|
||||
for i in 0..2 * message_len {
|
||||
if i == message_len {
|
||||
g_list.push(G::G1::ZERO);
|
||||
g_list.push(G::G1::ZERO.normalize());
|
||||
} else {
|
||||
g_list.push(g_cur);
|
||||
g_list.push(g_cur.normalize());
|
||||
}
|
||||
g_cur = g_cur.mul_scalar(alpha);
|
||||
}
|
||||
@@ -111,22 +113,22 @@ impl<G: Curve> GroupElements<G> {
|
||||
let mut g_hat_list = Vec::with_capacity(message_len);
|
||||
let mut g_hat_cur = G::G2::GENERATOR.mul_scalar(alpha);
|
||||
for _ in 0..message_len {
|
||||
g_hat_list.push(g_hat_cur);
|
||||
g_hat_cur = (g_hat_cur).mul_scalar(alpha);
|
||||
g_hat_list.push(g_hat_cur.normalize());
|
||||
g_hat_cur = g_hat_cur.mul_scalar(alpha);
|
||||
}
|
||||
g_hat_list
|
||||
},
|
||||
);
|
||||
|
||||
Self {
|
||||
g_list: OneBased::new(g_list),
|
||||
g_hat_list: OneBased::new(g_hat_list),
|
||||
message_len,
|
||||
}
|
||||
Self::from_vec(g_list, g_hat_list)
|
||||
}
|
||||
|
||||
pub fn from_vec(g_list: Vec<G::G1>, g_hat_list: Vec<G::G2>) -> Self {
|
||||
pub fn from_vec(
|
||||
g_list: Vec<Affine<G::Zp, G::G1>>,
|
||||
g_hat_list: Vec<Affine<G::Zp, G::G2>>,
|
||||
) -> Self {
|
||||
let message_len = g_hat_list.len();
|
||||
|
||||
Self {
|
||||
g_list: OneBased::new(g_list),
|
||||
g_hat_list: OneBased::new(g_hat_list),
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// TODO: refactor copy-pasted code in proof/verify
|
||||
// TODO: ask about metadata in hashing functions
|
||||
|
||||
use super::*;
|
||||
use core::marker::PhantomData;
|
||||
use rayon::prelude::*;
|
||||
@@ -22,8 +25,8 @@ pub struct PublicParams<G: Curve> {
|
||||
impl<G: Curve> PublicParams<G> {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn from_vec(
|
||||
g_list: Vec<G::G1>,
|
||||
g_hat_list: Vec<G::G2>,
|
||||
g_list: Vec<Affine<G::Zp, G::G1>>,
|
||||
g_hat_list: Vec<Affine<G::Zp, G::G2>>,
|
||||
big_d: usize,
|
||||
n: usize,
|
||||
d: usize,
|
||||
@@ -246,12 +249,12 @@ pub fn prove<G: Curve>(
|
||||
let mut dot = 0i128;
|
||||
for j in 0..d {
|
||||
let b = if i + j < d {
|
||||
b[d - j - i - 1]
|
||||
b[d - j - i - 1] as i128
|
||||
} else {
|
||||
b[2 * d - j - i - 1].wrapping_neg()
|
||||
-(b[2 * d - j - i - 1] as i128)
|
||||
};
|
||||
|
||||
dot += r[d - j - 1] as i128 * b as i128;
|
||||
dot += r[d - j - 1] as i128 * b;
|
||||
}
|
||||
|
||||
*r2 += dot;
|
||||
@@ -293,8 +296,9 @@ pub fn prove<G: Curve>(
|
||||
|
||||
let mut c_hat = g_hat.mul_scalar(gamma);
|
||||
for j in 1..big_d + 1 {
|
||||
let term = if w[j] { g_hat_list[j] } else { G::G2::ZERO };
|
||||
c_hat += term;
|
||||
if w[j] {
|
||||
c_hat += G::G2::projective(g_hat_list[j]);
|
||||
}
|
||||
}
|
||||
|
||||
let x_bytes = &*[
|
||||
@@ -337,7 +341,7 @@ pub fn prove<G: Curve>(
|
||||
compute_a_theta::<G>(theta0, d, a, k, b, &mut a_theta, t, delta, b_i, b_r, q);
|
||||
|
||||
let mut t = vec![G::Zp::ZERO; n];
|
||||
G::Zp::hash(
|
||||
G::Zp::hash_128bit(
|
||||
&mut t,
|
||||
&[
|
||||
&(1..n + 1)
|
||||
@@ -356,6 +360,7 @@ pub fn prove<G: Curve>(
|
||||
&[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()],
|
||||
);
|
||||
let [delta_eq, delta_y] = delta;
|
||||
let delta = [delta_eq, delta_y, delta_theta];
|
||||
|
||||
let mut poly_0 = vec![G::Zp::ZERO; n + 1];
|
||||
let mut poly_1 = vec![G::Zp::ZERO; big_d + 1];
|
||||
@@ -394,10 +399,11 @@ pub fn prove<G: Curve>(
|
||||
t_theta += theta0[d + i] * G::Zp::from_i64(c2[i]);
|
||||
}
|
||||
|
||||
let mut poly = G::Zp::poly_sub(
|
||||
&G::Zp::poly_mul(&poly_0, &poly_1),
|
||||
&G::Zp::poly_mul(&poly_2, &poly_3),
|
||||
let mul = rayon::join(
|
||||
|| G::Zp::poly_mul(&poly_0, &poly_1),
|
||||
|| G::Zp::poly_mul(&poly_2, &poly_3),
|
||||
);
|
||||
let mut poly = G::Zp::poly_sub(&mul.0, &mul.1);
|
||||
if poly.len() > n + 1 {
|
||||
poly[n + 1] -= t_theta * delta_theta;
|
||||
}
|
||||
@@ -418,6 +424,7 @@ pub fn prove<G: Curve>(
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let c_h = G::G1::multi_mul_scalar(&g_list.0[..n], &scalars);
|
||||
|
||||
let mut z = G::Zp::ZERO;
|
||||
@@ -497,6 +504,7 @@ pub fn prove<G: Curve>(
|
||||
}
|
||||
|
||||
let mut q = vec![G::Zp::ZERO; n];
|
||||
// https://en.wikipedia.org/wiki/Polynomial_long_division#Pseudocode
|
||||
for i in (0..n).rev() {
|
||||
poly[i] = poly[i] + z * poly[i + 1];
|
||||
q[i] = poly[i + 1];
|
||||
@@ -570,6 +578,10 @@ fn compute_a_theta<G: Curve>(
|
||||
|
||||
let theta1 = &theta0[..d];
|
||||
let theta2 = &theta0[d..];
|
||||
|
||||
let a = a.iter().map(|x| G::Zp::from_i64(*x)).collect::<Vec<_>>();
|
||||
let b = b.iter().map(|x| G::Zp::from_i64(*x)).collect::<Vec<_>>();
|
||||
|
||||
{
|
||||
let a_theta = &mut a_theta[..d];
|
||||
a_theta
|
||||
@@ -579,27 +591,24 @@ fn compute_a_theta<G: Curve>(
|
||||
let mut dot = G::Zp::ZERO;
|
||||
|
||||
for j in 0..d {
|
||||
let a = if i <= j {
|
||||
a[j - i]
|
||||
if i <= j {
|
||||
dot += a[j - i] * theta1[j];
|
||||
} else {
|
||||
a[d + j - i].wrapping_neg()
|
||||
};
|
||||
|
||||
dot += G::Zp::from_i64(a) * theta1[j];
|
||||
dot -= a[(d + j) - i] * theta1[j];
|
||||
}
|
||||
}
|
||||
|
||||
for j in 0..k {
|
||||
let b = if i + j < d {
|
||||
b[d - i - j - 1]
|
||||
if i + j < d {
|
||||
dot += b[d - i - j - 1] * theta2[j];
|
||||
} else {
|
||||
b[2 * d - i - j - 1].wrapping_neg()
|
||||
dot -= b[2 * d - i - j - 1] * theta2[j];
|
||||
};
|
||||
|
||||
dot += G::Zp::from_i64(b) * theta2[j];
|
||||
}
|
||||
*a_theta_i = dot;
|
||||
});
|
||||
}
|
||||
|
||||
let a_theta = &mut a_theta[d..];
|
||||
|
||||
let step = t.ilog2() as usize;
|
||||
@@ -726,7 +735,7 @@ pub fn verify<G: Curve>(
|
||||
}
|
||||
|
||||
let mut t = vec![G::Zp::ZERO; n];
|
||||
G::Zp::hash(
|
||||
G::Zp::hash_128bit(
|
||||
&mut t,
|
||||
&[
|
||||
&(1..n + 1)
|
||||
@@ -745,6 +754,7 @@ pub fn verify<G: Curve>(
|
||||
&[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()],
|
||||
);
|
||||
let [delta_eq, delta_y] = delta;
|
||||
let delta = [delta_eq, delta_y, delta_theta];
|
||||
|
||||
if let (Some(pi_kzg), Some(c_hat_t), Some(c_h)) = (pi_kzg, c_hat_t, c_h) {
|
||||
let mut z = G::Zp::ZERO;
|
||||
@@ -789,7 +799,11 @@ pub fn verify<G: Curve>(
|
||||
if e(pi, G::G2::GENERATOR)
|
||||
!= e(c_y.mul_scalar(delta_y) + c_h, c_hat)
|
||||
- e(c_y.mul_scalar(delta_eq), c_hat_t)
|
||||
- e(g_list[1], g_hat_list[n]).mul_scalar(t_theta * delta_theta)
|
||||
- e(
|
||||
G::G1::projective(g_list[1]),
|
||||
G::G2::projective(g_hat_list[n]),
|
||||
)
|
||||
.mul_scalar(t_theta * delta_theta)
|
||||
{
|
||||
return Err(());
|
||||
}
|
||||
@@ -822,13 +836,17 @@ pub fn verify<G: Curve>(
|
||||
|
||||
if e(c_h - G::G1::GENERATOR.mul_scalar(p_h), G::G2::GENERATOR)
|
||||
+ e(G::G1::GENERATOR, c_hat_t - G::G2::GENERATOR.mul_scalar(p_t)).mul_scalar(w)
|
||||
== e(pi_kzg, g_hat_list[1] - G::G2::GENERATOR.mul_scalar(z))
|
||||
== e(
|
||||
pi_kzg,
|
||||
G::G2::projective(g_hat_list[1]) - G::G2::GENERATOR.mul_scalar(z),
|
||||
)
|
||||
{
|
||||
Ok(())
|
||||
} else {
|
||||
Err(())
|
||||
}
|
||||
} else {
|
||||
// PERF: rewrite as multi_mul_scalar?
|
||||
let (term0, term1) = rayon::join(
|
||||
|| {
|
||||
let p = c_y.mul_scalar(delta_y)
|
||||
@@ -839,7 +857,7 @@ pub fn verify<G: Curve>(
|
||||
if i < big_d + 1 {
|
||||
factor += delta_theta * a_theta[i - 1];
|
||||
}
|
||||
g_list[n + 1 - i].mul_scalar(factor)
|
||||
G::G1::projective(g_list[n + 1 - i]).mul_scalar(factor)
|
||||
})
|
||||
.sum::<G::G1>();
|
||||
let q = c_hat;
|
||||
@@ -849,14 +867,14 @@ pub fn verify<G: Curve>(
|
||||
let p = c_y;
|
||||
let q = (1..n + 1)
|
||||
.into_par_iter()
|
||||
.map(|i| g_hat_list[i].mul_scalar(delta_eq * t[i]))
|
||||
.map(|i| G::G2::projective(g_hat_list[i]).mul_scalar(delta_eq * t[i]))
|
||||
.sum::<G::G2>();
|
||||
e(p, q)
|
||||
},
|
||||
);
|
||||
let term2 = {
|
||||
let p = g_list[1];
|
||||
let q = g_hat_list[n];
|
||||
let p = G::G1::projective(g_list[1]);
|
||||
let q = G::G2::projective(g_hat_list[n]);
|
||||
e(p, q)
|
||||
};
|
||||
|
||||
|
||||
@@ -6,7 +6,10 @@ pub struct PublicParams<G: Curve> {
|
||||
}
|
||||
|
||||
impl<G: Curve> PublicParams<G> {
|
||||
pub fn from_vec(g_list: Vec<G::G1>, g_hat_list: Vec<G::G2>) -> Self {
|
||||
pub fn from_vec(
|
||||
g_list: Vec<Affine<G::Zp, G::G1>>,
|
||||
g_hat_list: Vec<Affine<G::Zp, G::G2>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
g_lists: GroupElements::from_vec(g_list, g_hat_list),
|
||||
}
|
||||
@@ -54,7 +57,8 @@ pub fn commit<G: Curve>(
|
||||
let g_hat = G::G2::GENERATOR;
|
||||
|
||||
let r = G::Zp::rand(rng);
|
||||
let v_hat = g_hat.mul_scalar(r) + public.g_lists.g_hat_list[1].mul_scalar(G::Zp::from_u64(x));
|
||||
let v_hat = g_hat.mul_scalar(r)
|
||||
+ G::G2::projective(public.g_lists.g_hat_list[1]).mul_scalar(G::Zp::from_u64(x));
|
||||
|
||||
(PublicCommit { l, v_hat }, PrivateCommit { x, r })
|
||||
}
|
||||
@@ -87,7 +91,7 @@ pub fn prove<G: Curve>(
|
||||
let mut c = g_hat.mul_scalar(gamma);
|
||||
for j in 1..l + 1 {
|
||||
let term = if x_bits[j] != 0 {
|
||||
g_hat_list[j]
|
||||
G::G2::projective(g_hat_list[j])
|
||||
} else {
|
||||
G::G2::ZERO
|
||||
};
|
||||
@@ -96,13 +100,13 @@ pub fn prove<G: Curve>(
|
||||
c
|
||||
};
|
||||
|
||||
let mut proof_x = -g_list[n].mul_scalar(r);
|
||||
let mut proof_x = -G::G1::projective(g_list[n]).mul_scalar(r);
|
||||
for i in 1..l + 1 {
|
||||
let mut term = g_list[n + 1 - i].mul_scalar(gamma);
|
||||
let mut term = G::G1::projective(g_list[n + 1 - i]).mul_scalar(gamma);
|
||||
for j in 1..l + 1 {
|
||||
if j != i {
|
||||
let term_inner = if x_bits[j] != 0 {
|
||||
g_list[n + 1 - i + j]
|
||||
G::G1::projective(g_list[n + 1 - i + j])
|
||||
} else {
|
||||
G::G1::ZERO
|
||||
};
|
||||
@@ -124,7 +128,7 @@ pub fn prove<G: Curve>(
|
||||
let y = OneBased(y);
|
||||
let mut c_y = g.mul_scalar(gamma_y);
|
||||
for j in 1..l + 1 {
|
||||
c_y += g_list[n + 1 - j].mul_scalar(y[j] * G::Zp::from_u64(x_bits[j]));
|
||||
c_y += G::G1::projective(g_list[n + 1 - j]).mul_scalar(y[j] * G::Zp::from_u64(x_bits[j]));
|
||||
}
|
||||
|
||||
let y_bytes = &*(1..n + 1)
|
||||
@@ -145,11 +149,11 @@ pub fn prove<G: Curve>(
|
||||
|
||||
let mut proof_eq = G::G1::ZERO;
|
||||
for i in 1..n + 1 {
|
||||
let mut numerator = g_list[n + 1 - i].mul_scalar(gamma);
|
||||
let mut numerator = G::G1::projective(g_list[n + 1 - i]).mul_scalar(gamma);
|
||||
for j in 1..n + 1 {
|
||||
if j != i {
|
||||
let term = if x_bits[j] != 0 {
|
||||
g_list[n + 1 - i + j]
|
||||
G::G1::projective(g_list[n + 1 - i + j])
|
||||
} else {
|
||||
G::G1::ZERO
|
||||
};
|
||||
@@ -158,10 +162,11 @@ pub fn prove<G: Curve>(
|
||||
}
|
||||
numerator = numerator.mul_scalar(t[i] * y[i]);
|
||||
|
||||
let mut denominator = g_list[i].mul_scalar(gamma_y);
|
||||
let mut denominator = G::G1::projective(g_list[i]).mul_scalar(gamma_y);
|
||||
for j in 1..n + 1 {
|
||||
if j != i {
|
||||
denominator += g_list[n + 1 - j + i].mul_scalar(y[j] * G::Zp::from_u64(x_bits[j]));
|
||||
denominator += G::G1::projective(g_list[n + 1 - j + i])
|
||||
.mul_scalar(y[j] * G::Zp::from_u64(x_bits[j]));
|
||||
}
|
||||
}
|
||||
denominator = denominator.mul_scalar(t[i]);
|
||||
@@ -171,14 +176,16 @@ pub fn prove<G: Curve>(
|
||||
|
||||
let mut proof_y = g.mul_scalar(gamma_y);
|
||||
for j in 1..n + 1 {
|
||||
proof_y -= g_list[n + 1 - j].mul_scalar(y[j] * G::Zp::from_u64(1 - x_bits[j]));
|
||||
proof_y -=
|
||||
G::G1::projective(g_list[n + 1 - j]).mul_scalar(y[j] * G::Zp::from_u64(1 - x_bits[j]));
|
||||
}
|
||||
proof_y = proof_y.mul_scalar(gamma);
|
||||
for i in 1..n + 1 {
|
||||
let mut term = g_list[i].mul_scalar(gamma_y);
|
||||
let mut term = G::G1::projective(g_list[i]).mul_scalar(gamma_y);
|
||||
for j in 1..n + 1 {
|
||||
if j != i {
|
||||
term -= g_list[n + 1 - j + i].mul_scalar(y[j] * G::Zp::from_u64(1 - x_bits[j]));
|
||||
term -= G::G1::projective(g_list[n + 1 - j + i])
|
||||
.mul_scalar(y[j] * G::Zp::from_u64(1 - x_bits[j]));
|
||||
}
|
||||
}
|
||||
let term = if x_bits[i] != 0 { term } else { G::G1::ZERO };
|
||||
@@ -202,7 +209,8 @@ pub fn prove<G: Curve>(
|
||||
let mut proof_v = G::G1::ZERO;
|
||||
for i in 2..n + 1 {
|
||||
proof_v += G::G1::mul_scalar(
|
||||
g_list[n + 1 - i].mul_scalar(r) + g_list[n + 2 - i].mul_scalar(G::Zp::from_u64(x)),
|
||||
G::G1::projective(g_list[n + 1 - i]).mul_scalar(r)
|
||||
+ G::G1::projective(g_list[n + 2 - i]).mul_scalar(G::Zp::from_u64(x)),
|
||||
s[i],
|
||||
);
|
||||
}
|
||||
@@ -300,7 +308,7 @@ pub fn verify<G: Curve>(
|
||||
let numerator = {
|
||||
let mut p = c_y.mul_scalar(delta_y);
|
||||
for i in 1..n + 1 {
|
||||
let g = g_list[n + 1 - i];
|
||||
let g = G::G1::projective(g_list[n + 1 - i]);
|
||||
if i <= l {
|
||||
p += g.mul_scalar(delta_x * G::Zp::from_u64(1 << (i - 1)));
|
||||
}
|
||||
@@ -309,16 +317,16 @@ pub fn verify<G: Curve>(
|
||||
e(p, c_hat)
|
||||
};
|
||||
let denominator_0 = {
|
||||
let mut p = g_list[n].mul_scalar(delta_x);
|
||||
let mut p = G::G1::projective(g_list[n]).mul_scalar(delta_x);
|
||||
for i in 2..n + 1 {
|
||||
p -= g_list[n + 1 - i].mul_scalar(delta_v * s[i]);
|
||||
p -= G::G1::projective(g_list[n + 1 - i]).mul_scalar(delta_v * s[i]);
|
||||
}
|
||||
e(p, v_hat)
|
||||
};
|
||||
let denominator_1 = {
|
||||
let mut q = G::G2::ZERO;
|
||||
for i in 1..n + 1 {
|
||||
q += g_hat_list[i].mul_scalar(delta_eq * t[i]);
|
||||
q += G::G2::projective(g_hat_list[i]).mul_scalar(delta_eq * t[i]);
|
||||
}
|
||||
e(c_y, q)
|
||||
};
|
||||
|
||||
@@ -19,8 +19,8 @@ pub struct PublicParams<G: Curve> {
|
||||
|
||||
impl<G: Curve> PublicParams<G> {
|
||||
pub fn from_vec(
|
||||
g_list: Vec<G::G1>,
|
||||
g_hat_list: Vec<G::G2>,
|
||||
g_list: Vec<Affine<G::Zp, G::G1>>,
|
||||
g_hat_list: Vec<Affine<G::Zp, G::G2>>,
|
||||
d: usize,
|
||||
big_n: usize,
|
||||
big_m: usize,
|
||||
@@ -268,7 +268,11 @@ pub fn prove<G: Curve>(
|
||||
|
||||
let mut c_hat = g_hat.mul_scalar(gamma);
|
||||
for j in 1..big_d + 1 {
|
||||
let term = if w[j] { g_hat_list[j] } else { G::G2::ZERO };
|
||||
let term = if w[j] {
|
||||
G::G2::projective(g_hat_list[j])
|
||||
} else {
|
||||
G::G2::ZERO
|
||||
};
|
||||
c_hat += term;
|
||||
}
|
||||
|
||||
@@ -763,7 +767,11 @@ pub fn verify<G: Curve>(
|
||||
if e(pi, G::G2::GENERATOR)
|
||||
!= e(c_y.mul_scalar(delta_y) + c_h, c_hat)
|
||||
- e(c_y.mul_scalar(delta_eq), c_hat_t)
|
||||
- e(g_list[1], g_hat_list[n]).mul_scalar(t_theta * delta_theta)
|
||||
- e(
|
||||
G::G1::projective(g_list[1]),
|
||||
G::G2::projective(g_hat_list[n]),
|
||||
)
|
||||
.mul_scalar(t_theta * delta_theta)
|
||||
{
|
||||
return Err(());
|
||||
}
|
||||
@@ -796,7 +804,10 @@ pub fn verify<G: Curve>(
|
||||
|
||||
if e(c_h - G::G1::GENERATOR.mul_scalar(p_h), G::G2::GENERATOR)
|
||||
+ e(G::G1::GENERATOR, c_hat_t - G::G2::GENERATOR.mul_scalar(p_t)).mul_scalar(w)
|
||||
== e(pi_kzg, g_hat_list[1] - G::G2::GENERATOR.mul_scalar(z))
|
||||
== e(
|
||||
pi_kzg,
|
||||
G::G2::projective(g_hat_list[1]) - G::G2::GENERATOR.mul_scalar(z),
|
||||
)
|
||||
{
|
||||
Ok(())
|
||||
} else {
|
||||
@@ -813,7 +824,7 @@ pub fn verify<G: Curve>(
|
||||
if i < big_d + 1 {
|
||||
factor += delta_theta * a_theta[i - 1];
|
||||
}
|
||||
g_list[n + 1 - i].mul_scalar(factor)
|
||||
G::G1::projective(g_list[n + 1 - i]).mul_scalar(factor)
|
||||
})
|
||||
.sum::<G::G1>();
|
||||
let q = c_hat;
|
||||
@@ -823,14 +834,14 @@ pub fn verify<G: Curve>(
|
||||
let p = c_y;
|
||||
let q = (1..n + 1)
|
||||
.into_par_iter()
|
||||
.map(|i| g_hat_list[i].mul_scalar(delta_eq * t[i]))
|
||||
.map(|i| G::G2::projective(g_hat_list[i]).mul_scalar(delta_eq * t[i]))
|
||||
.sum::<G::G2>();
|
||||
e(p, q)
|
||||
},
|
||||
);
|
||||
let term2 = {
|
||||
let p = g_list[1];
|
||||
let q = g_hat_list[n];
|
||||
let p = G::G1::projective(g_list[1]);
|
||||
let q = G::G2::projective(g_hat_list[n]);
|
||||
e(p, q)
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user