diff --git a/tfhe-zk-pok/Cargo.toml b/tfhe-zk-pok/Cargo.toml index 475959546..9d1a65bf9 100644 --- a/tfhe-zk-pok/Cargo.toml +++ b/tfhe-zk-pok/Cargo.toml @@ -12,15 +12,15 @@ description = "tfhe-zk-pok: An implementation of zero-knowledge proofs of encryp # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -ark-bls12-381 = "0.4.0" -ark-ec = "0.4.2" -ark-ff = "0.4.2" -ark-poly = "0.4.2" +ark-bls12-381 = { package = "tfhe-ark-bls12-381", version = "0.4.0" } +ark-ec = { package = "tfhe-ark-ec", version = "0.4.2" } +ark-ff = { package = "tfhe-ark-ff", version = "0.4.3" } +ark-poly = { package = "tfhe-ark-poly", version = "0.4.2" } +ark-serialize = { version = "0.4.2" } rand = "0.8.5" rayon = "1.8.0" sha3 = "0.10.8" serde = { version = "~1.0", features = ["derive"] } -ark-serialize = { version = "0.4.2" } zeroize = "1.7.0" [dev-dependencies] diff --git a/tfhe-zk-pok/src/curve_446/mod.rs b/tfhe-zk-pok/src/curve_446/mod.rs index d6e216c8c..ef6446ed0 100644 --- a/tfhe-zk-pok/src/curve_446/mod.rs +++ b/tfhe-zk-pok/src/curve_446/mod.rs @@ -466,8 +466,8 @@ pub mod g1 { use ark_ec::bls12::Bls12Config; use ark_ec::models::CurveConfig; use ark_ec::short_weierstrass::{Affine, SWCurveConfig}; - use ark_ec::{bls12, AffineRepr, Group}; - use ark_ff::{Field, MontFp, One, PrimeField, Zero}; + use ark_ec::{bls12, AdditiveGroup, AffineRepr, PrimeGroup}; + use ark_ff::{MontFp, One, PrimeField, Zero}; use ark_serialize::{Compress, SerializationError}; use core::ops::Neg; @@ -631,7 +631,7 @@ pub mod g2 { use ark_ec::models::CurveConfig; use ark_ec::short_weierstrass::{Affine, SWCurveConfig}; use ark_ec::{bls12, AffineRepr}; - use ark_ff::{MontFp, Zero}; + use ark_ff::MontFp; use ark_serialize::{Compress, SerializationError}; pub type G2Affine = bls12::G2Affine; diff --git a/tfhe-zk-pok/src/curve_api.rs b/tfhe-zk-pok/src/curve_api.rs index 48dc43ccc..a1a0bdb90 100644 --- a/tfhe-zk-pok/src/curve_api.rs +++ b/tfhe-zk-pok/src/curve_api.rs @@ -1,4 +1,4 @@ -use ark_ec::{CurveGroup, Group, VariableBaseMSM}; +use ark_ec::{AdditiveGroup as Group, CurveGroup, VariableBaseMSM}; use ark_ff::{BigInt, Field, MontFp, Zero}; use ark_poly::univariate::DensePolynomial; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; @@ -37,6 +37,8 @@ impl fmt::Debug for MontIntDisplay<'_, T> { } } +pub mod msm; + pub mod bls12_381; pub mod bls12_446; @@ -44,6 +46,7 @@ pub trait FieldOps: Copy + Send + Sync + + core::fmt::Debug + core::ops::AddAssign + core::ops::SubAssign + core::ops::Add @@ -62,6 +65,7 @@ pub trait FieldOps: fn to_bytes(self) -> impl AsRef<[u8]>; fn rand(rng: &mut dyn rand::RngCore) -> Self; fn hash(values: &mut [Self], data: &[&[u8]]); + fn hash_128bit(values: &mut [Self], data: &[&[u8]]); fn poly_mul(p: &[Self], q: &[Self]) -> Vec; fn poly_sub(p: &[Self], q: &[Self]) -> Vec { use core::iter::zip; @@ -113,10 +117,22 @@ pub trait CurveGroupOps: const GENERATOR: Self; const BYTE_SIZE: usize; + type Affine: Copy + + Send + + Sync + + core::fmt::Debug + + serde::Serialize + + for<'de> serde::Deserialize<'de> + + CanonicalSerialize + + CanonicalDeserialize; + + fn projective(affine: Self::Affine) -> Self; + fn mul_scalar(self, scalar: Zp) -> Self; - fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self; + fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[Zp]) -> Self; fn to_bytes(self) -> impl AsRef<[u8]>; fn double(self) -> Self; + fn normalize(self) -> Self::Affine; } pub trait PairingGroupOps: @@ -164,6 +180,9 @@ impl FieldOps for bls12_381::Zp { fn hash(values: &mut [Self], data: &[&[u8]]) { Self::hash(values, data) } + fn hash_128bit(values: &mut [Self], data: &[&[u8]]) { + Self::hash_128bit(values, data) + } fn poly_mul(p: &[Self], q: &[Self]) -> Vec { let p = p.iter().map(|x| x.inner).collect(); @@ -182,13 +201,20 @@ impl CurveGroupOps for bls12_381::G1 { const ZERO: Self = Self::ZERO; const GENERATOR: Self = Self::GENERATOR; const BYTE_SIZE: usize = Self::BYTE_SIZE; + type Affine = bls12_381::G1Affine; + + fn projective(affine: Self::Affine) -> Self { + Self { + inner: affine.inner.into(), + } + } fn mul_scalar(self, scalar: bls12_381::Zp) -> Self { self.mul_scalar(scalar) } - fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_381::Zp]) -> Self { - Self::multi_mul_scalar(bases, scalars) + fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_381::Zp]) -> Self { + Self::Affine::multi_mul_scalar(bases, scalars) } fn to_bytes(self) -> impl AsRef<[u8]> { @@ -198,19 +224,32 @@ impl CurveGroupOps for bls12_381::G1 { fn double(self) -> Self { self.double() } + + fn normalize(self) -> Self::Affine { + Self::Affine { + inner: self.inner.into_affine(), + } + } } impl CurveGroupOps for bls12_381::G2 { const ZERO: Self = Self::ZERO; const GENERATOR: Self = Self::GENERATOR; const BYTE_SIZE: usize = Self::BYTE_SIZE; + type Affine = bls12_381::G2Affine; + + fn projective(affine: Self::Affine) -> Self { + Self { + inner: affine.inner.into(), + } + } fn mul_scalar(self, scalar: bls12_381::Zp) -> Self { self.mul_scalar(scalar) } - fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_381::Zp]) -> Self { - Self::multi_mul_scalar(bases, scalars) + fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_381::Zp]) -> Self { + Self::Affine::multi_mul_scalar(bases, scalars) } fn to_bytes(self) -> impl AsRef<[u8]> { @@ -220,6 +259,12 @@ impl CurveGroupOps for bls12_381::G2 { fn double(self) -> Self { self.double() } + + fn normalize(self) -> Self::Affine { + Self::Affine { + inner: self.inner.into_affine(), + } + } } impl PairingGroupOps for bls12_381::Gt { @@ -254,6 +299,9 @@ impl FieldOps for bls12_446::Zp { fn hash(values: &mut [Self], data: &[&[u8]]) { Self::hash(values, data) } + fn hash_128bit(values: &mut [Self], data: &[&[u8]]) { + Self::hash_128bit(values, data) + } fn poly_mul(p: &[Self], q: &[Self]) -> Vec { let p = p.iter().map(|x| x.inner).collect(); @@ -272,13 +320,21 @@ impl CurveGroupOps for bls12_446::G1 { const ZERO: Self = Self::ZERO; const GENERATOR: Self = Self::GENERATOR; const BYTE_SIZE: usize = Self::BYTE_SIZE; + type Affine = bls12_446::G1Affine; + + fn projective(affine: Self::Affine) -> Self { + Self { + inner: affine.inner.into(), + } + } fn mul_scalar(self, scalar: bls12_446::Zp) -> Self { self.mul_scalar(scalar) } - fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_446::Zp]) -> Self { - Self::multi_mul_scalar(bases, scalars) + fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_446::Zp]) -> Self { + msm::msm_wnaf_g1_446(bases, scalars) + // Self::Affine::multi_mul_scalar(bases, scalars) } fn to_bytes(self) -> impl AsRef<[u8]> { @@ -288,19 +344,32 @@ impl CurveGroupOps for bls12_446::G1 { fn double(self) -> Self { self.double() } + + fn normalize(self) -> Self::Affine { + Self::Affine { + inner: self.inner.into_affine(), + } + } } impl CurveGroupOps for bls12_446::G2 { const ZERO: Self = Self::ZERO; const GENERATOR: Self = Self::GENERATOR; const BYTE_SIZE: usize = Self::BYTE_SIZE; + type Affine = bls12_446::G2Affine; + + fn projective(affine: Self::Affine) -> Self { + Self { + inner: affine.inner.into(), + } + } fn mul_scalar(self, scalar: bls12_446::Zp) -> Self { self.mul_scalar(scalar) } - fn multi_mul_scalar(bases: &[Self], scalars: &[bls12_446::Zp]) -> Self { - Self::multi_mul_scalar(bases, scalars) + fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[bls12_446::Zp]) -> Self { + Self::Affine::multi_mul_scalar(bases, scalars) } fn to_bytes(self) -> impl AsRef<[u8]> { @@ -310,6 +379,12 @@ impl CurveGroupOps for bls12_446::G2 { fn double(self) -> Self { self.double() } + + fn normalize(self) -> Self::Affine { + Self::Affine { + inner: self.inner.into_affine(), + } + } } impl PairingGroupOps for bls12_446::Gt { diff --git a/tfhe-zk-pok/src/curve_api/bls12_381.rs b/tfhe-zk-pok/src/curve_api/bls12_381.rs index 8c5bb2c01..db579165b 100644 --- a/tfhe-zk-pok/src/curve_api/bls12_381.rs +++ b/tfhe-zk-pok/src/curve_api/bls12_381.rs @@ -36,6 +36,39 @@ fn bigint_to_bytes(x: [u64; 6]) -> [u8; 6 * 8] { mod g1 { use super::*; + #[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + Serialize, + Deserialize, + Hash, + CanonicalSerialize, + CanonicalDeserialize, + )] + #[repr(transparent)] + pub struct G1Affine { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(crate) inner: ark_bls12_381::g1::G1Affine, + } + + impl G1Affine { + pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G1 { + // SAFETY: interpreting a `repr(transparent)` pointer as its contents. + G1 { + inner: ark_bls12_381::g1::G1Projective::msm( + unsafe { + &*(bases as *const [G1Affine] as *const [ark_bls12_381::g1::G1Affine]) + }, + unsafe { &*(scalars as *const [Zp] as *const [ark_bls12_381::Fr]) }, + ) + .unwrap(), + } + } + } + #[derive( Copy, Clone, @@ -179,6 +212,39 @@ mod g1 { mod g2 { use super::*; + #[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + Serialize, + Deserialize, + Hash, + CanonicalSerialize, + CanonicalDeserialize, + )] + #[repr(transparent)] + pub struct G2Affine { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(crate) inner: ark_bls12_381::g2::G2Affine, + } + + impl G2Affine { + pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G2 { + // SAFETY: interpreting a `repr(transparent)` pointer as its contents. + G2 { + inner: ark_bls12_381::g2::G2Projective::msm( + unsafe { + &*(bases as *const [G2Affine] as *const [ark_bls12_381::g2::G2Affine]) + }, + unsafe { &*(scalars as *const [Zp] as *const [ark_bls12_381::Fr]) }, + ) + .unwrap(), + } + } + } + #[derive( Copy, Clone, @@ -193,7 +259,7 @@ mod g2 { #[repr(transparent)] pub struct G2 { #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] - pub(super) inner: ark_bls12_381::G2Projective, + pub(crate) inner: ark_bls12_381::G2Projective, } impl fmt::Debug for G2 { @@ -640,6 +706,25 @@ mod zp { *value = Zp::from_raw_u64x6(bytes.map(u64::from_le_bytes)); } } + + pub fn hash_128bit(values: &mut [Zp], data: &[&[u8]]) { + use sha3::digest::{ExtendableOutput, Update, XofReader}; + + let mut hasher = sha3::Shake256::default(); + for data in data { + hasher.update(data); + } + let mut reader = hasher.finalize_xof(); + + for value in values { + let mut bytes = [0u8; 2 * 8]; + reader.read(&mut bytes); + let limbs: [u64; 2] = unsafe { core::mem::transmute(bytes) }; + *value = Zp { + inner: BigInt([limbs[0], limbs[1], 0, 0]).into(), + }; + } + } } impl Add for Zp { @@ -714,8 +799,8 @@ mod zp { } } -pub use g1::G1; -pub use g2::G2; +pub use g1::{G1Affine, G1}; +pub use g2::{G2Affine, G2}; pub use gt::Gt; pub use zp::Zp; diff --git a/tfhe-zk-pok/src/curve_api/bls12_446.rs b/tfhe-zk-pok/src/curve_api/bls12_446.rs index d628ed708..1ea2aa572 100644 --- a/tfhe-zk-pok/src/curve_api/bls12_446.rs +++ b/tfhe-zk-pok/src/curve_api/bls12_446.rs @@ -36,6 +36,39 @@ fn bigint_to_bytes(x: [u64; 7]) -> [u8; 7 * 8] { mod g1 { use super::*; + #[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + Serialize, + Deserialize, + Hash, + CanonicalSerialize, + CanonicalDeserialize, + )] + #[repr(transparent)] + pub struct G1Affine { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(crate) inner: crate::curve_446::g1::G1Affine, + } + + impl G1Affine { + pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G1 { + // SAFETY: interpreting a `repr(transparent)` pointer as its contents. + G1 { + inner: crate::curve_446::g1::G1Projective::msm( + unsafe { + &*(bases as *const [G1Affine] as *const [crate::curve_446::g1::G1Affine]) + }, + unsafe { &*(scalars as *const [Zp] as *const [crate::curve_446::Fr]) }, + ) + .unwrap(), + } + } + } + #[derive( Copy, Clone, @@ -93,17 +126,17 @@ mod g1 { pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self { use rayon::prelude::*; - let n_threads = rayon::current_num_threads(); - let chunk_size = bases.len().div_ceil(n_threads); - bases + let bases = bases .par_iter() .map(|&x| x.inner.into_affine()) - .chunks(chunk_size) - .zip(scalars.par_iter().map(|&x| x.inner).chunks(chunk_size)) - .map(|(bases, scalars)| Self { - inner: crate::curve_446::g1::G1Projective::msm(&bases, &scalars).unwrap(), + .collect::>(); + // SAFETY: interpreting a `repr(transparent)` pointer as its contents. + Self { + inner: crate::curve_446::g1::G1Projective::msm(&bases, unsafe { + &*(scalars as *const [Zp] as *const [crate::curve_446::Fr]) }) - .sum::() + .unwrap(), + } } pub fn to_bytes(self) -> [u8; Self::BYTE_SIZE] { @@ -178,6 +211,39 @@ mod g1 { mod g2 { use super::*; + #[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + Serialize, + Deserialize, + Hash, + CanonicalSerialize, + CanonicalDeserialize, + )] + #[repr(transparent)] + pub struct G2Affine { + #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] + pub(crate) inner: crate::curve_446::g2::G2Affine, + } + + impl G2Affine { + pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G2 { + // SAFETY: interpreting a `repr(transparent)` pointer as its contents. + G2 { + inner: crate::curve_446::g2::G2Projective::msm( + unsafe { + &*(bases as *const [G2Affine] as *const [crate::curve_446::g2::G2Affine]) + }, + unsafe { &*(scalars as *const [Zp] as *const [crate::curve_446::Fr]) }, + ) + .unwrap(), + } + } + } + #[derive( Copy, Clone, @@ -192,7 +258,7 @@ mod g2 { #[repr(transparent)] pub struct G2 { #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] - pub(super) inner: crate::curve_446::g2::G2Projective, + pub(crate) inner: crate::curve_446::g2::G2Projective, } impl fmt::Debug for G2 { @@ -672,7 +738,7 @@ mod zp { #[repr(transparent)] pub struct Zp { #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] - pub(crate) inner: crate::curve_446::Fr, + pub inner: crate::curve_446::Fr, } impl fmt::Debug for Zp { @@ -772,6 +838,25 @@ mod zp { *value = Zp::from_raw_u64x7(bytes.map(u64::from_le_bytes)); } } + + pub fn hash_128bit(values: &mut [Zp], data: &[&[u8]]) { + use sha3::digest::{ExtendableOutput, Update, XofReader}; + + let mut hasher = sha3::Shake256::default(); + for data in data { + hasher.update(data); + } + let mut reader = hasher.finalize_xof(); + + for value in values { + let mut bytes = [0u8; 2 * 8]; + reader.read(&mut bytes); + let limbs: [u64; 2] = unsafe { core::mem::transmute(bytes) }; + *value = Zp { + inner: BigInt([limbs[0], limbs[1], 0, 0, 0]).into(), + }; + } + } } impl Add for Zp { @@ -846,8 +931,8 @@ mod zp { } } -pub use g1::G1; -pub use g2::G2; +pub use g1::{G1Affine, G1}; +pub use g2::{G2Affine, G2}; pub use gt::Gt; pub use zp::Zp; diff --git a/tfhe-zk-pok/src/curve_api/msm.rs b/tfhe-zk-pok/src/curve_api/msm.rs new file mode 100644 index 000000000..b052512f4 --- /dev/null +++ b/tfhe-zk-pok/src/curve_api/msm.rs @@ -0,0 +1,442 @@ +use ark_ec::short_weierstrass::Affine; +use ark_ec::AffineRepr; +use ark_ff::{AdditiveGroup, BigInt, BigInteger, Field, Fp, PrimeField}; +use rayon::prelude::*; + +fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator + '_ { + let scalar = a.as_ref(); + let radix: u64 = 1 << w; + let window_mask: u64 = radix - 1; + + let mut carry = 0u64; + let num_bits = if num_bits == 0 { + a.num_bits() as usize + } else { + num_bits + }; + let digits_count = (num_bits + w - 1) / w; + + (0..digits_count).map(move |i| { + // Construct a buffer of bits of the scalar, starting at `bit_offset`. + let bit_offset = i * w; + let u64_idx = bit_offset / 64; + let bit_idx = bit_offset % 64; + // Read the bits from the scalar + let bit_buf = if bit_idx < 64 - w || u64_idx == scalar.len() - 1 { + // This window's bits are contained in a single u64, + // or it's the last u64 anyway. + scalar[u64_idx] >> bit_idx + } else { + // Combine the current u64's bits with the bits from the next u64 + (scalar[u64_idx] >> bit_idx) | (scalar[1 + u64_idx] << (64 - bit_idx)) + }; + + // Read the actual coefficient value from the window + let coef = carry + (bit_buf & window_mask); // coef = [0, 2^r) + + // Recenter coefficients from [0,2^w) to [-2^w/2, 2^w/2) + carry = (coef + radix / 2) >> w; + let mut digit = (coef as i64) - (carry << w) as i64; + + if i == digits_count - 1 { + digit += (carry << w) as i64; + } + digit + }) +} + +// Compute msm using windowed non-adjacent form +pub fn msm_wnaf_g1_446( + bases: &[super::bls12_446::G1Affine], + scalars: &[super::bls12_446::Zp], +) -> super::bls12_446::G1 { + use super::bls12_446::*; + let num_bits = 299usize; + type BaseField = Fp, 7>; + + assert_eq!(bases.len(), scalars.len()); + + let size = bases.len(); + let scalars = &*scalars + .par_iter() + .map(|x| x.inner.into_bigint()) + .collect::>(); + + let c = if size < 32 { + 3 + } else { + // natural log approx + (size.ilog2() as usize * 69 / 100) + 2 + }; + + let digits_count = (num_bits + c - 1) / c; + let scalar_digits = scalars + .into_par_iter() + .flat_map_iter(|s| make_digits(s, c, num_bits)) + .collect::>(); + + let zero = G1Affine { + inner: Affine::zero(), + }; + + let window_sums: Vec<_> = (0..digits_count) + .into_par_iter() + .map(|i| { + let n = 1 << c; + let mut indices = vec![vec![]; n]; + let mut d = vec![BaseField::ZERO; n + 1]; + let mut e = vec![BaseField::ZERO; n + 1]; + + for (idx, digits) in scalar_digits.chunks(digits_count).enumerate() { + use core::cmp::Ordering; + // digits is the digits thing of the first scalar? + let scalar = digits[i]; + match 0.cmp(&scalar) { + Ordering::Less => indices[(scalar - 1) as usize].push(idx), + Ordering::Greater => indices[(-scalar - 1) as usize].push(!idx), + Ordering::Equal => (), + } + } + + let mut buckets = vec![zero; 1 << c]; + + loop { + d[0] = BaseField::ONE; + for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices).enumerate() { + if let Some(idx) = idx.last().copied() { + let value = if idx >> (usize::BITS - 1) == 1 { + let mut val = bases[!idx]; + val.inner.y = -val.inner.y; + val + } else { + bases[idx] + }; + + if !bucket.inner.infinity { + let a = value.inner.x - bucket.inner.x; + if a != BaseField::ZERO { + d[k + 1] = d[k] * a; + } else if value.inner.y == bucket.inner.y { + d[k + 1] = d[k] * value.inner.y.double(); + } else { + d[k + 1] = d[k]; + } + continue; + } + } + d[k + 1] = d[k]; + } + e[n] = d[n].inverse().unwrap(); + + for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices) + .enumerate() + .rev() + { + if let Some(idx) = idx.last().copied() { + let value = if idx >> (usize::BITS - 1) == 1 { + let mut val = bases[!idx]; + val.inner.y = -val.inner.y; + val + } else { + bases[idx] + }; + + if !bucket.inner.infinity { + let a = value.inner.x - bucket.inner.x; + if a != BaseField::ZERO { + e[k] = e[k + 1] * a; + } else if value.inner.y == bucket.inner.y { + e[k] = e[k + 1] * value.inner.y.double(); + } else { + e[k] = e[k + 1]; + } + continue; + } + } + e[k] = e[k + 1]; + } + + let d = &d[..n]; + let e = &e[1..]; + + let mut empty = true; + for ((&d, &e), (bucket, idx)) in core::iter::zip( + core::iter::zip(d, e), + core::iter::zip(&mut buckets, &mut indices), + ) { + empty &= idx.len() <= 1; + if let Some(idx) = idx.pop() { + let value = if idx >> (usize::BITS - 1) == 1 { + let mut val = bases[!idx]; + val.inner.y = -val.inner.y; + val + } else { + bases[idx] + }; + + if !bucket.inner.infinity { + let x1 = bucket.inner.x; + let x2 = value.inner.x; + let y1 = bucket.inner.y; + let y2 = value.inner.y; + + let eq_x = x1 == x2; + + if eq_x && y1 != y2 { + bucket.inner.infinity = true; + } else { + let r = d * e; + let m = if eq_x { + let x1 = x1.square(); + x1 + x1.double() + } else { + y2 - y1 + }; + let m = m * r; + + let x3 = m.square() - x1 - x2; + let y3 = m * (x1 - x3) - y1; + bucket.inner.x = x3; + bucket.inner.y = y3; + } + } else { + *bucket = value; + } + } + } + + if empty { + break; + } + } + + let mut running_sum = G1::ZERO; + let mut res = G1::ZERO; + buckets.into_iter().rev().for_each(|b| { + running_sum.inner += b.inner; + res += running_sum; + }); + res + }) + .collect(); + + // We store the sum for the lowest window. + let lowest = *window_sums.first().unwrap(); + + // We're traversing windows from high to low. + lowest + + window_sums[1..] + .iter() + .rev() + .fold(G1::ZERO, |mut total, &sum_i| { + total += sum_i; + for _ in 0..c { + total = total.double(); + } + total + }) +} + +// Compute msm using windowed non-adjacent form +pub fn msm_wnaf_g1_446_extended( + bases: &[super::bls12_446::G1Affine], + scalars: &[super::bls12_446::Zp], +) -> super::bls12_446::G1 { + use super::bls12_446::*; + type BaseField = Fp, 7>; + + // let num_bits = 75usize; + // let mask = BigInt([!0, (1 << 11) - 1, 0, 0, 0]); + // let scalars = &*scalars + // .par_iter() + // .map(|x| x.inner.into_bigint()) + // .flat_map_iter(|x| (0..4).map(move |i| (x >> (75 * i)) & mask)) + // .collect::>(); + + let num_bits = 150usize; + let mask = BigInt([!0, !0, (1 << 22) - 1, 0, 0]); + let scalars = &*scalars + .par_iter() + .map(|x| x.inner.into_bigint()) + .flat_map_iter(|x| (0..2).map(move |i| (x >> (150 * i)) & mask)) + .collect::>(); + + assert_eq!(bases.len(), scalars.len()); + + let size = bases.len(); + + let c = if size < 32 { + 3 + } else { + // natural log approx + (size.ilog2() as usize * 69 / 100) + 2 + }; + let c = c - 3; + + let digits_count = (num_bits + c - 1) / c; + let scalar_digits = scalars + .into_par_iter() + .flat_map_iter(|s| make_digits(s, c, num_bits)) + .collect::>(); + + let zero = G1Affine { + inner: Affine::zero(), + }; + + let window_sums: Vec<_> = (0..digits_count) + .into_par_iter() + .map(|i| { + let n = 1 << c; + let mut indices = vec![vec![]; n]; + let mut d = vec![BaseField::ZERO; n + 1]; + let mut e = vec![BaseField::ZERO; n + 1]; + + for (idx, digits) in scalar_digits.chunks(digits_count).enumerate() { + use core::cmp::Ordering; + // digits is the digits thing of the first scalar? + let scalar = digits[i]; + match 0.cmp(&scalar) { + Ordering::Less => indices[(scalar - 1) as usize].push(idx), + Ordering::Greater => indices[(-scalar - 1) as usize].push(!idx), + Ordering::Equal => (), + } + } + + let mut buckets = vec![zero; 1 << c]; + + loop { + d[0] = BaseField::ONE; + for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices).enumerate() { + if let Some(idx) = idx.last().copied() { + let value = if idx >> (usize::BITS - 1) == 1 { + let mut val = bases[!idx]; + val.inner.y = -val.inner.y; + val + } else { + bases[idx] + }; + + if !bucket.inner.infinity { + let a = value.inner.x - bucket.inner.x; + if a != BaseField::ZERO { + d[k + 1] = d[k] * a; + } else if value.inner.y == bucket.inner.y { + d[k + 1] = d[k] * value.inner.y.double(); + } else { + d[k + 1] = d[k]; + } + continue; + } + } + d[k + 1] = d[k]; + } + e[n] = d[n].inverse().unwrap(); + + for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices) + .enumerate() + .rev() + { + if let Some(idx) = idx.last().copied() { + let value = if idx >> (usize::BITS - 1) == 1 { + let mut val = bases[!idx]; + val.inner.y = -val.inner.y; + val + } else { + bases[idx] + }; + + if !bucket.inner.infinity { + let a = value.inner.x - bucket.inner.x; + if a != BaseField::ZERO { + e[k] = e[k + 1] * a; + } else if value.inner.y == bucket.inner.y { + e[k] = e[k + 1] * value.inner.y.double(); + } else { + e[k] = e[k + 1]; + } + continue; + } + } + e[k] = e[k + 1]; + } + + let d = &d[..n]; + let e = &e[1..]; + + let mut empty = true; + for ((&d, &e), (bucket, idx)) in core::iter::zip( + core::iter::zip(d, e), + core::iter::zip(&mut buckets, &mut indices), + ) { + empty &= idx.len() <= 1; + if let Some(idx) = idx.pop() { + let value = if idx >> (usize::BITS - 1) == 1 { + let mut val = bases[!idx]; + val.inner.y = -val.inner.y; + val + } else { + bases[idx] + }; + + if !bucket.inner.infinity { + let x1 = bucket.inner.x; + let x2 = value.inner.x; + let y1 = bucket.inner.y; + let y2 = value.inner.y; + + let eq_x = x1 == x2; + + if eq_x && y1 != y2 { + bucket.inner.infinity = true; + } else { + let r = d * e; + let m = if eq_x { + let x1 = x1.square(); + x1 + x1.double() + } else { + y2 - y1 + }; + let m = m * r; + + let x3 = m.square() - x1 - x2; + let y3 = m * (x1 - x3) - y1; + bucket.inner.x = x3; + bucket.inner.y = y3; + } + } else { + *bucket = value; + } + } + } + + if empty { + break; + } + } + + let mut running_sum = G1::ZERO; + let mut res = G1::ZERO; + buckets.into_iter().rev().for_each(|b| { + running_sum.inner += b.inner; + res += running_sum; + }); + res + }) + .collect(); + + // We store the sum for the lowest window. + let lowest = *window_sums.first().unwrap(); + + // We're traversing windows from high to low. + lowest + + window_sums[1..] + .iter() + .rev() + .fold(G1::ZERO, |mut total, &sum_i| { + total += sum_i; + for _ in 0..c { + total = total.double(); + } + total + }) +} diff --git a/tfhe-zk-pok/src/proofs/binary.rs b/tfhe-zk-pok/src/proofs/binary.rs index ea7a1d8be..7c324acbc 100644 --- a/tfhe-zk-pok/src/proofs/binary.rs +++ b/tfhe-zk-pok/src/proofs/binary.rs @@ -6,7 +6,10 @@ pub struct PublicParams { } impl PublicParams { - pub fn from_vec(g_list: Vec, g_hat_list: Vec) -> Self { + pub fn from_vec( + g_list: Vec>, + g_hat_list: Vec>, + ) -> Self { Self { g_lists: GroupElements::from_vec(g_list, g_hat_list), } @@ -57,7 +60,7 @@ pub fn commit( let mut c_hat = g_hat.mul_scalar(gamma); for j in 1..n + 1 { let term = if x[j] != 0 { - public.g_lists.g_hat_list[j] + G::G2::projective(public.g_lists.g_hat_list[j]) } else { G::G2::ZERO }; @@ -91,7 +94,7 @@ pub fn prove( let mut c_y = g.mul_scalar(gamma_y); for j in 1..n + 1 { - c_y += (g_list[n + 1 - j]).mul_scalar(y[j] * G::Zp::from_u64(x[j])); + c_y += (G::G1::projective(g_list[n + 1 - j])).mul_scalar(y[j] * G::Zp::from_u64(x[j])); } let y_bytes = &*(1..n + 1) @@ -143,7 +146,7 @@ pub fn prove( let mut proof = g.mul_scalar(poly[0]); for i in 1..poly.len() { - proof += g_list[i].mul_scalar(poly[i]); + proof += G::G1::projective(g_list[i]).mul_scalar(poly[i]); } proof }; @@ -190,7 +193,7 @@ pub fn verify( let numerator = { let mut p = c_y.mul_scalar(delta_y); for i in 1..n + 1 { - let gy = g_list[n + 1 - i].mul_scalar(y[i]); + let gy = G::G1::projective(g_list[n + 1 - i]).mul_scalar(y[i]); p += gy.mul_scalar(delta_eq).mul_scalar(t[i]) - gy.mul_scalar(delta_y); } e(p, c_hat) @@ -198,7 +201,9 @@ pub fn verify( let denominator = { let mut q = G::G2::ZERO; for i in 1..n + 1 { - q += g_hat_list[i].mul_scalar(delta_eq).mul_scalar(t[i]); + q += G::G2::projective(g_hat_list[i]) + .mul_scalar(delta_eq) + .mul_scalar(t[i]); } e(c_y, q) }; diff --git a/tfhe-zk-pok/src/proofs/index.rs b/tfhe-zk-pok/src/proofs/index.rs index 641c45140..29061b09b 100644 --- a/tfhe-zk-pok/src/proofs/index.rs +++ b/tfhe-zk-pok/src/proofs/index.rs @@ -6,7 +6,10 @@ pub struct PublicParams { } impl PublicParams { - pub fn from_vec(g_list: Vec, g_hat_list: Vec) -> Self { + pub fn from_vec( + g_list: Vec>, + g_hat_list: Vec>, + ) -> Self { Self { g_lists: GroupElements::from_vec(g_list, g_hat_list), } @@ -55,7 +58,7 @@ pub fn commit( let mut c = g.mul_scalar(gamma); for j in 1..n + 1 { - let term = public.g_lists.g_list[j].mul_scalar(G::Zp::from_u64(m[j])); + let term = G::G1::projective(public.g_lists.g_list[j]).mul_scalar(G::Zp::from_u64(m[j])); c += term; } @@ -80,11 +83,11 @@ pub fn prove( let gamma = private.gamma; let g_list = &public.0.g_lists.g_list; - let mut pi = g_list[n + 1 - i].mul_scalar(gamma); + let mut pi = G::G1::projective(g_list[n + 1 - i]).mul_scalar(gamma); for j in 1..n + 1 { if i != j { let term = if m[j] & 1 == 1 { - g_list[n + 1 - i + j] + G::G1::projective(g_list[n + 1 - i + j]) } else { G::G1::ZERO }; @@ -110,8 +113,13 @@ pub fn verify( let n = public.0.g_lists.message_len; let i = index + 1; - let lhs = e(c, g_hat_list[n + 1 - i]); - let rhs = e(proof.pi, g_hat) + (e(g_list[1], g_hat_list[n])).mul_scalar(G::Zp::from_u64(mi)); + let lhs = e(c, G::G2::projective(g_hat_list[n + 1 - i])); + let rhs = e(proof.pi, g_hat) + + (e( + G::G1::projective(g_list[1]), + G::G2::projective(g_hat_list[n]), + )) + .mul_scalar(G::Zp::from_u64(mi)); if lhs == rhs { Ok(()) diff --git a/tfhe-zk-pok/src/proofs/mod.rs b/tfhe-zk-pok/src/proofs/mod.rs index 589aae942..2dd235f10 100644 --- a/tfhe-zk-pok/src/proofs/mod.rs +++ b/tfhe-zk-pok/src/proofs/mod.rs @@ -75,6 +75,8 @@ impl> IndexMut for OneBased { } } +pub type Affine = >::Affine; + #[derive( Clone, Debug, serde::Serialize, serde::Deserialize, CanonicalSerialize, CanonicalDeserialize, )] @@ -83,8 +85,8 @@ impl> IndexMut for OneBased { serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize" ))] struct GroupElements { - g_list: OneBased>, - g_hat_list: OneBased>, + g_list: OneBased>>, + g_hat_list: OneBased>>, message_len: usize, } @@ -98,9 +100,9 @@ impl GroupElements { for i in 0..2 * message_len { if i == message_len { - g_list.push(G::G1::ZERO); + g_list.push(G::G1::ZERO.normalize()); } else { - g_list.push(g_cur); + g_list.push(g_cur.normalize()); } g_cur = g_cur.mul_scalar(alpha); } @@ -111,22 +113,22 @@ impl GroupElements { let mut g_hat_list = Vec::with_capacity(message_len); let mut g_hat_cur = G::G2::GENERATOR.mul_scalar(alpha); for _ in 0..message_len { - g_hat_list.push(g_hat_cur); - g_hat_cur = (g_hat_cur).mul_scalar(alpha); + g_hat_list.push(g_hat_cur.normalize()); + g_hat_cur = g_hat_cur.mul_scalar(alpha); } g_hat_list }, ); - Self { - g_list: OneBased::new(g_list), - g_hat_list: OneBased::new(g_hat_list), - message_len, - } + Self::from_vec(g_list, g_hat_list) } - pub fn from_vec(g_list: Vec, g_hat_list: Vec) -> Self { + pub fn from_vec( + g_list: Vec>, + g_hat_list: Vec>, + ) -> Self { let message_len = g_hat_list.len(); + Self { g_list: OneBased::new(g_list), g_hat_list: OneBased::new(g_hat_list), diff --git a/tfhe-zk-pok/src/proofs/pke.rs b/tfhe-zk-pok/src/proofs/pke.rs index 1682cf3bd..231f05fe1 100644 --- a/tfhe-zk-pok/src/proofs/pke.rs +++ b/tfhe-zk-pok/src/proofs/pke.rs @@ -1,3 +1,6 @@ +// TODO: refactor copy-pasted code in proof/verify +// TODO: ask about metadata in hashing functions + use super::*; use core::marker::PhantomData; use rayon::prelude::*; @@ -22,8 +25,8 @@ pub struct PublicParams { impl PublicParams { #[allow(clippy::too_many_arguments)] pub fn from_vec( - g_list: Vec, - g_hat_list: Vec, + g_list: Vec>, + g_hat_list: Vec>, big_d: usize, n: usize, d: usize, @@ -246,12 +249,12 @@ pub fn prove( let mut dot = 0i128; for j in 0..d { let b = if i + j < d { - b[d - j - i - 1] + b[d - j - i - 1] as i128 } else { - b[2 * d - j - i - 1].wrapping_neg() + -(b[2 * d - j - i - 1] as i128) }; - dot += r[d - j - 1] as i128 * b as i128; + dot += r[d - j - 1] as i128 * b; } *r2 += dot; @@ -293,8 +296,9 @@ pub fn prove( let mut c_hat = g_hat.mul_scalar(gamma); for j in 1..big_d + 1 { - let term = if w[j] { g_hat_list[j] } else { G::G2::ZERO }; - c_hat += term; + if w[j] { + c_hat += G::G2::projective(g_hat_list[j]); + } } let x_bytes = &*[ @@ -337,7 +341,7 @@ pub fn prove( compute_a_theta::(theta0, d, a, k, b, &mut a_theta, t, delta, b_i, b_r, q); let mut t = vec![G::Zp::ZERO; n]; - G::Zp::hash( + G::Zp::hash_128bit( &mut t, &[ &(1..n + 1) @@ -356,6 +360,7 @@ pub fn prove( &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], ); let [delta_eq, delta_y] = delta; + let delta = [delta_eq, delta_y, delta_theta]; let mut poly_0 = vec![G::Zp::ZERO; n + 1]; let mut poly_1 = vec![G::Zp::ZERO; big_d + 1]; @@ -394,10 +399,11 @@ pub fn prove( t_theta += theta0[d + i] * G::Zp::from_i64(c2[i]); } - let mut poly = G::Zp::poly_sub( - &G::Zp::poly_mul(&poly_0, &poly_1), - &G::Zp::poly_mul(&poly_2, &poly_3), + let mul = rayon::join( + || G::Zp::poly_mul(&poly_0, &poly_1), + || G::Zp::poly_mul(&poly_2, &poly_3), ); + let mut poly = G::Zp::poly_sub(&mul.0, &mul.1); if poly.len() > n + 1 { poly[n + 1] -= t_theta * delta_theta; } @@ -418,6 +424,7 @@ pub fn prove( } }) .collect::>(); + let c_h = G::G1::multi_mul_scalar(&g_list.0[..n], &scalars); let mut z = G::Zp::ZERO; @@ -497,6 +504,7 @@ pub fn prove( } let mut q = vec![G::Zp::ZERO; n]; + // https://en.wikipedia.org/wiki/Polynomial_long_division#Pseudocode for i in (0..n).rev() { poly[i] = poly[i] + z * poly[i + 1]; q[i] = poly[i + 1]; @@ -570,6 +578,10 @@ fn compute_a_theta( let theta1 = &theta0[..d]; let theta2 = &theta0[d..]; + + let a = a.iter().map(|x| G::Zp::from_i64(*x)).collect::>(); + let b = b.iter().map(|x| G::Zp::from_i64(*x)).collect::>(); + { let a_theta = &mut a_theta[..d]; a_theta @@ -579,27 +591,24 @@ fn compute_a_theta( let mut dot = G::Zp::ZERO; for j in 0..d { - let a = if i <= j { - a[j - i] + if i <= j { + dot += a[j - i] * theta1[j]; } else { - a[d + j - i].wrapping_neg() - }; - - dot += G::Zp::from_i64(a) * theta1[j]; + dot -= a[(d + j) - i] * theta1[j]; + } } for j in 0..k { - let b = if i + j < d { - b[d - i - j - 1] + if i + j < d { + dot += b[d - i - j - 1] * theta2[j]; } else { - b[2 * d - i - j - 1].wrapping_neg() + dot -= b[2 * d - i - j - 1] * theta2[j]; }; - - dot += G::Zp::from_i64(b) * theta2[j]; } *a_theta_i = dot; }); } + let a_theta = &mut a_theta[d..]; let step = t.ilog2() as usize; @@ -726,7 +735,7 @@ pub fn verify( } let mut t = vec![G::Zp::ZERO; n]; - G::Zp::hash( + G::Zp::hash_128bit( &mut t, &[ &(1..n + 1) @@ -745,6 +754,7 @@ pub fn verify( &[x_bytes, c_hat.to_bytes().as_ref(), c_y.to_bytes().as_ref()], ); let [delta_eq, delta_y] = delta; + let delta = [delta_eq, delta_y, delta_theta]; if let (Some(pi_kzg), Some(c_hat_t), Some(c_h)) = (pi_kzg, c_hat_t, c_h) { let mut z = G::Zp::ZERO; @@ -789,7 +799,11 @@ pub fn verify( if e(pi, G::G2::GENERATOR) != e(c_y.mul_scalar(delta_y) + c_h, c_hat) - e(c_y.mul_scalar(delta_eq), c_hat_t) - - e(g_list[1], g_hat_list[n]).mul_scalar(t_theta * delta_theta) + - e( + G::G1::projective(g_list[1]), + G::G2::projective(g_hat_list[n]), + ) + .mul_scalar(t_theta * delta_theta) { return Err(()); } @@ -822,13 +836,17 @@ pub fn verify( if e(c_h - G::G1::GENERATOR.mul_scalar(p_h), G::G2::GENERATOR) + e(G::G1::GENERATOR, c_hat_t - G::G2::GENERATOR.mul_scalar(p_t)).mul_scalar(w) - == e(pi_kzg, g_hat_list[1] - G::G2::GENERATOR.mul_scalar(z)) + == e( + pi_kzg, + G::G2::projective(g_hat_list[1]) - G::G2::GENERATOR.mul_scalar(z), + ) { Ok(()) } else { Err(()) } } else { + // PERF: rewrite as multi_mul_scalar? let (term0, term1) = rayon::join( || { let p = c_y.mul_scalar(delta_y) @@ -839,7 +857,7 @@ pub fn verify( if i < big_d + 1 { factor += delta_theta * a_theta[i - 1]; } - g_list[n + 1 - i].mul_scalar(factor) + G::G1::projective(g_list[n + 1 - i]).mul_scalar(factor) }) .sum::(); let q = c_hat; @@ -849,14 +867,14 @@ pub fn verify( let p = c_y; let q = (1..n + 1) .into_par_iter() - .map(|i| g_hat_list[i].mul_scalar(delta_eq * t[i])) + .map(|i| G::G2::projective(g_hat_list[i]).mul_scalar(delta_eq * t[i])) .sum::(); e(p, q) }, ); let term2 = { - let p = g_list[1]; - let q = g_hat_list[n]; + let p = G::G1::projective(g_list[1]); + let q = G::G2::projective(g_hat_list[n]); e(p, q) }; diff --git a/tfhe-zk-pok/src/proofs/range.rs b/tfhe-zk-pok/src/proofs/range.rs index 499378f32..d26270bf3 100644 --- a/tfhe-zk-pok/src/proofs/range.rs +++ b/tfhe-zk-pok/src/proofs/range.rs @@ -6,7 +6,10 @@ pub struct PublicParams { } impl PublicParams { - pub fn from_vec(g_list: Vec, g_hat_list: Vec) -> Self { + pub fn from_vec( + g_list: Vec>, + g_hat_list: Vec>, + ) -> Self { Self { g_lists: GroupElements::from_vec(g_list, g_hat_list), } @@ -54,7 +57,8 @@ pub fn commit( let g_hat = G::G2::GENERATOR; let r = G::Zp::rand(rng); - let v_hat = g_hat.mul_scalar(r) + public.g_lists.g_hat_list[1].mul_scalar(G::Zp::from_u64(x)); + let v_hat = g_hat.mul_scalar(r) + + G::G2::projective(public.g_lists.g_hat_list[1]).mul_scalar(G::Zp::from_u64(x)); (PublicCommit { l, v_hat }, PrivateCommit { x, r }) } @@ -87,7 +91,7 @@ pub fn prove( let mut c = g_hat.mul_scalar(gamma); for j in 1..l + 1 { let term = if x_bits[j] != 0 { - g_hat_list[j] + G::G2::projective(g_hat_list[j]) } else { G::G2::ZERO }; @@ -96,13 +100,13 @@ pub fn prove( c }; - let mut proof_x = -g_list[n].mul_scalar(r); + let mut proof_x = -G::G1::projective(g_list[n]).mul_scalar(r); for i in 1..l + 1 { - let mut term = g_list[n + 1 - i].mul_scalar(gamma); + let mut term = G::G1::projective(g_list[n + 1 - i]).mul_scalar(gamma); for j in 1..l + 1 { if j != i { let term_inner = if x_bits[j] != 0 { - g_list[n + 1 - i + j] + G::G1::projective(g_list[n + 1 - i + j]) } else { G::G1::ZERO }; @@ -124,7 +128,7 @@ pub fn prove( let y = OneBased(y); let mut c_y = g.mul_scalar(gamma_y); for j in 1..l + 1 { - c_y += g_list[n + 1 - j].mul_scalar(y[j] * G::Zp::from_u64(x_bits[j])); + c_y += G::G1::projective(g_list[n + 1 - j]).mul_scalar(y[j] * G::Zp::from_u64(x_bits[j])); } let y_bytes = &*(1..n + 1) @@ -145,11 +149,11 @@ pub fn prove( let mut proof_eq = G::G1::ZERO; for i in 1..n + 1 { - let mut numerator = g_list[n + 1 - i].mul_scalar(gamma); + let mut numerator = G::G1::projective(g_list[n + 1 - i]).mul_scalar(gamma); for j in 1..n + 1 { if j != i { let term = if x_bits[j] != 0 { - g_list[n + 1 - i + j] + G::G1::projective(g_list[n + 1 - i + j]) } else { G::G1::ZERO }; @@ -158,10 +162,11 @@ pub fn prove( } numerator = numerator.mul_scalar(t[i] * y[i]); - let mut denominator = g_list[i].mul_scalar(gamma_y); + let mut denominator = G::G1::projective(g_list[i]).mul_scalar(gamma_y); for j in 1..n + 1 { if j != i { - denominator += g_list[n + 1 - j + i].mul_scalar(y[j] * G::Zp::from_u64(x_bits[j])); + denominator += G::G1::projective(g_list[n + 1 - j + i]) + .mul_scalar(y[j] * G::Zp::from_u64(x_bits[j])); } } denominator = denominator.mul_scalar(t[i]); @@ -171,14 +176,16 @@ pub fn prove( let mut proof_y = g.mul_scalar(gamma_y); for j in 1..n + 1 { - proof_y -= g_list[n + 1 - j].mul_scalar(y[j] * G::Zp::from_u64(1 - x_bits[j])); + proof_y -= + G::G1::projective(g_list[n + 1 - j]).mul_scalar(y[j] * G::Zp::from_u64(1 - x_bits[j])); } proof_y = proof_y.mul_scalar(gamma); for i in 1..n + 1 { - let mut term = g_list[i].mul_scalar(gamma_y); + let mut term = G::G1::projective(g_list[i]).mul_scalar(gamma_y); for j in 1..n + 1 { if j != i { - term -= g_list[n + 1 - j + i].mul_scalar(y[j] * G::Zp::from_u64(1 - x_bits[j])); + term -= G::G1::projective(g_list[n + 1 - j + i]) + .mul_scalar(y[j] * G::Zp::from_u64(1 - x_bits[j])); } } let term = if x_bits[i] != 0 { term } else { G::G1::ZERO }; @@ -202,7 +209,8 @@ pub fn prove( let mut proof_v = G::G1::ZERO; for i in 2..n + 1 { proof_v += G::G1::mul_scalar( - g_list[n + 1 - i].mul_scalar(r) + g_list[n + 2 - i].mul_scalar(G::Zp::from_u64(x)), + G::G1::projective(g_list[n + 1 - i]).mul_scalar(r) + + G::G1::projective(g_list[n + 2 - i]).mul_scalar(G::Zp::from_u64(x)), s[i], ); } @@ -300,7 +308,7 @@ pub fn verify( let numerator = { let mut p = c_y.mul_scalar(delta_y); for i in 1..n + 1 { - let g = g_list[n + 1 - i]; + let g = G::G1::projective(g_list[n + 1 - i]); if i <= l { p += g.mul_scalar(delta_x * G::Zp::from_u64(1 << (i - 1))); } @@ -309,16 +317,16 @@ pub fn verify( e(p, c_hat) }; let denominator_0 = { - let mut p = g_list[n].mul_scalar(delta_x); + let mut p = G::G1::projective(g_list[n]).mul_scalar(delta_x); for i in 2..n + 1 { - p -= g_list[n + 1 - i].mul_scalar(delta_v * s[i]); + p -= G::G1::projective(g_list[n + 1 - i]).mul_scalar(delta_v * s[i]); } e(p, v_hat) }; let denominator_1 = { let mut q = G::G2::ZERO; for i in 1..n + 1 { - q += g_hat_list[i].mul_scalar(delta_eq * t[i]); + q += G::G2::projective(g_hat_list[i]).mul_scalar(delta_eq * t[i]); } e(c_y, q) }; diff --git a/tfhe-zk-pok/src/proofs/rlwe.rs b/tfhe-zk-pok/src/proofs/rlwe.rs index 5dc2eb345..de7e8077f 100644 --- a/tfhe-zk-pok/src/proofs/rlwe.rs +++ b/tfhe-zk-pok/src/proofs/rlwe.rs @@ -19,8 +19,8 @@ pub struct PublicParams { impl PublicParams { pub fn from_vec( - g_list: Vec, - g_hat_list: Vec, + g_list: Vec>, + g_hat_list: Vec>, d: usize, big_n: usize, big_m: usize, @@ -268,7 +268,11 @@ pub fn prove( let mut c_hat = g_hat.mul_scalar(gamma); for j in 1..big_d + 1 { - let term = if w[j] { g_hat_list[j] } else { G::G2::ZERO }; + let term = if w[j] { + G::G2::projective(g_hat_list[j]) + } else { + G::G2::ZERO + }; c_hat += term; } @@ -763,7 +767,11 @@ pub fn verify( if e(pi, G::G2::GENERATOR) != e(c_y.mul_scalar(delta_y) + c_h, c_hat) - e(c_y.mul_scalar(delta_eq), c_hat_t) - - e(g_list[1], g_hat_list[n]).mul_scalar(t_theta * delta_theta) + - e( + G::G1::projective(g_list[1]), + G::G2::projective(g_hat_list[n]), + ) + .mul_scalar(t_theta * delta_theta) { return Err(()); } @@ -796,7 +804,10 @@ pub fn verify( if e(c_h - G::G1::GENERATOR.mul_scalar(p_h), G::G2::GENERATOR) + e(G::G1::GENERATOR, c_hat_t - G::G2::GENERATOR.mul_scalar(p_t)).mul_scalar(w) - == e(pi_kzg, g_hat_list[1] - G::G2::GENERATOR.mul_scalar(z)) + == e( + pi_kzg, + G::G2::projective(g_hat_list[1]) - G::G2::GENERATOR.mul_scalar(z), + ) { Ok(()) } else { @@ -813,7 +824,7 @@ pub fn verify( if i < big_d + 1 { factor += delta_theta * a_theta[i - 1]; } - g_list[n + 1 - i].mul_scalar(factor) + G::G1::projective(g_list[n + 1 - i]).mul_scalar(factor) }) .sum::(); let q = c_hat; @@ -823,14 +834,14 @@ pub fn verify( let p = c_y; let q = (1..n + 1) .into_par_iter() - .map(|i| g_hat_list[i].mul_scalar(delta_eq * t[i])) + .map(|i| G::G2::projective(g_hat_list[i]).mul_scalar(delta_eq * t[i])) .sum::(); e(p, q) }, ); let term2 = { - let p = g_list[1]; - let q = g_hat_list[n]; + let p = G::G1::projective(g_list[1]); + let q = G::G2::projective(g_hat_list[n]); e(p, q) };