From 49b8475517e7ccc4fd4e0ac4831574b9022bcbe1 Mon Sep 17 00:00:00 2001 From: exfinen <47593166+exfinen@users.noreply.github.com> Date: Thu, 9 Nov 2023 18:32:16 +0900 Subject: [PATCH] add arith modules based on mcl. add mcl-based groth16 --- Cargo.lock | 18 + Cargo.toml | 2 +- README.md | 5 +- src/building_block/mcl/mcl_fr.rs | 370 +++++ src/building_block/mcl/mcl_g1.rs | 186 +++ src/building_block/mcl/mcl_g2.rs | 199 +++ src/building_block/mcl/mcl_gt.rs | 200 +++ src/building_block/mcl/mcl_initializer.rs | 26 + src/building_block/mcl/mcl_sparse_matrix.rs | 824 +++++++++++ src/building_block/mcl/mcl_sparse_vec.rs | 559 ++++++++ src/building_block/mcl/mod.rs | 11 + src/building_block/mcl/pairing.rs | 130 ++ src/building_block/mcl/polynomial.rs | 1233 +++++++++++++++++ src/building_block/mcl/qap/config.rs | 1 + src/building_block/mcl/qap/constraint.rs | 28 + src/building_block/mcl/qap/equation_parser.rs | 686 +++++++++ src/building_block/mcl/qap/gate.rs | 251 ++++ src/building_block/mcl/qap/gates/adder.rs | 131 ++ .../mcl/qap/gates/arith_circuit.rs | 12 + .../mcl/qap/gates/bool_circuit.rs | 76 + src/building_block/mcl/qap/gates/mod.rs | 4 + src/building_block/mcl/qap/gates/number.rs | 119 ++ src/building_block/mcl/qap/mod.rs | 9 + src/building_block/mcl/qap/qap.rs | 326 +++++ src/building_block/mcl/qap/r1cs.rs | 159 +++ src/building_block/mcl/qap/r1cs_tmpl.rs | 421 ++++++ src/building_block/mcl/qap/term.rs | 33 + src/building_block/mod.rs | 1 + src/zk/w_trusted_setup/groth16_mcl/crs.rs | 138 ++ src/zk/w_trusted_setup/groth16_mcl/mod.rs | 5 + src/zk/w_trusted_setup/groth16_mcl/proof.rs | 12 + src/zk/w_trusted_setup/groth16_mcl/prover.rs | 179 +++ .../w_trusted_setup/groth16_mcl/verifier.rs | 55 + src/zk/w_trusted_setup/groth16_mcl/wires.rs | 78 ++ src/zk/w_trusted_setup/mod.rs | 1 + 35 files changed, 6484 insertions(+), 4 deletions(-) create mode 100644 src/building_block/mcl/mcl_fr.rs create mode 100644 src/building_block/mcl/mcl_g1.rs create mode 100644 src/building_block/mcl/mcl_g2.rs create mode 100644 src/building_block/mcl/mcl_gt.rs create mode 100644 src/building_block/mcl/mcl_initializer.rs create mode 100644 src/building_block/mcl/mcl_sparse_matrix.rs create mode 100644 src/building_block/mcl/mcl_sparse_vec.rs create mode 100644 src/building_block/mcl/mod.rs create mode 100644 src/building_block/mcl/pairing.rs create mode 100644 src/building_block/mcl/polynomial.rs create mode 100644 src/building_block/mcl/qap/config.rs create mode 100644 src/building_block/mcl/qap/constraint.rs create mode 100644 src/building_block/mcl/qap/equation_parser.rs create mode 100644 src/building_block/mcl/qap/gate.rs create mode 100644 src/building_block/mcl/qap/gates/adder.rs create mode 100644 src/building_block/mcl/qap/gates/arith_circuit.rs create mode 100644 src/building_block/mcl/qap/gates/bool_circuit.rs create mode 100644 src/building_block/mcl/qap/gates/mod.rs create mode 100644 src/building_block/mcl/qap/gates/number.rs create mode 100644 src/building_block/mcl/qap/mod.rs create mode 100644 src/building_block/mcl/qap/qap.rs create mode 100644 src/building_block/mcl/qap/r1cs.rs create mode 100644 src/building_block/mcl/qap/r1cs_tmpl.rs create mode 100644 src/building_block/mcl/qap/term.rs create mode 100644 src/zk/w_trusted_setup/groth16_mcl/crs.rs create mode 100644 src/zk/w_trusted_setup/groth16_mcl/mod.rs create mode 100644 src/zk/w_trusted_setup/groth16_mcl/proof.rs create mode 100644 src/zk/w_trusted_setup/groth16_mcl/prover.rs create mode 100644 src/zk/w_trusted_setup/groth16_mcl/verifier.rs create mode 100644 src/zk/w_trusted_setup/groth16_mcl/wires.rs diff --git a/Cargo.lock b/Cargo.lock index 0425ba2..4b65e4a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,6 +20,15 @@ dependencies = [ "wyz", ] +[[package]] +name = "cc" +version = "1.0.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +dependencies = [ + "libc", +] + [[package]] name = "cfg-if" version = "1.0.0" @@ -55,6 +64,14 @@ version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ad5c14e80759d0939d013e6ca49930e59fc53dd8e5009132f76240c179380c09" +[[package]] +name = "mcl_rust" +version = "0.0.1" +source = "git+https://github.com/herumi/mcl-rust.git#50056e4f1cd7525042887cbcfe9aecaf8c4858da" +dependencies = [ + "cc", +] + [[package]] name = "memchr" version = "2.5.0" @@ -183,6 +200,7 @@ version = "0.1.0" dependencies = [ "bitvec", "hex", + "mcl_rust", "nom", "num-bigint", "num-traits", diff --git a/Cargo.toml b/Cargo.toml index 2f7e0b5..12cb7d1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,4 +12,4 @@ num-traits = "0.2" once_cell = "1.18.0" rand = "0.8.5" rand_chacha = "0.3.1" - +mcl_rust = { git = "https://github.com/herumi/mcl-rust.git" } diff --git a/README.md b/README.md index 128c941..a325608 100644 --- a/README.md +++ b/README.md @@ -6,9 +6,9 @@ To implement cryptographic primitives in the simplest form without using any opt ## What's implemented so far - Groth16 zk-SNARK - - Proof generation and verification + - An implementation fully based on zk-toolkit + - An implementation utilizing BLS12-381 curve of external mcl library - Pinnochio zk-SNARK (protocol 2) - - Proof generation and verification - Common zk-SNARK compnents - Equation parser - R1CS @@ -36,4 +36,3 @@ To implement cryptographic primitives in the simplest form without using any opt ## What's NOT implemented so far - Arbitrary-precision unsigned integer - Random number generator - diff --git a/src/building_block/mcl/mcl_fr.rs b/src/building_block/mcl/mcl_fr.rs new file mode 100644 index 0000000..06b6e50 --- /dev/null +++ b/src/building_block/mcl/mcl_fr.rs @@ -0,0 +1,370 @@ +use mcl_rust::*; +use std::{ + cmp::Ordering, + convert::From, + fmt, + ops::{Add, Sub, Mul, Neg}, + hash::{Hash, Hasher}, +}; +use num_traits::Zero; + +#[derive(Clone)] +pub struct MclFr { + pub v: Fr, +} + +impl MclFr { + pub fn new() -> Self { + let v = Fr::zero(); + MclFr::from(&v) + } + + pub fn inv(&self) -> Self { + let mut v = Fr::zero(); + Fr::inv(&mut v, &self.v); + MclFr::from(&v) + } + + pub fn sq(&self) -> Self { + let mut v = Fr::zero(); + Fr::sqr(&mut v, &self.v); + MclFr::from(&v) + } + + pub fn rand(exclude_zero: bool) -> Self { + let mut v = Fr::zero(); + while { + Fr::set_by_csprng(&mut v); + v.is_zero() && exclude_zero + } {} + MclFr::from(&v) + } + + pub fn inc(&mut self) { + let v = &self.v + &Fr::from_int(1); + self.v = v; + } + + pub fn to_usize(&self) -> usize { + self.v.get_str(10).parse().unwrap() + } +} + +impl Zero for MclFr { + fn is_zero(&self) -> bool { + self.v.is_zero() + } + + fn zero() -> Self { + MclFr { v: Fr::zero() } + } +} + +impl From for MclFr { + fn from(value: i32) -> Self { + let v = Fr::from_int(value); + MclFr { v } + } +} + +impl From for MclFr { + fn from(value: usize) -> Self { + let value: i32 = value as i32; + let v = Fr::from_int(value); + MclFr { v } + } +} + +impl From<&Fr> for MclFr { + fn from(v: &Fr) -> Self { + MclFr { v: v.clone() } + } +} + +impl From<&str> for MclFr { + fn from(s: &str) -> Self { + let mut v = Fr::zero(); + Fr::set_str(&mut v, s, 10); + MclFr { v } + } +} + +impl From for MclFr { + fn from(b: bool) -> Self { + let v = { + if b { + Fr::from_int(1) + } else { + Fr::zero() + } + }; + MclFr { v } + } +} + +impl Ord for MclFr { + fn cmp(&self, other: &Self) -> Ordering { + let r = &self.v.cmp(&other.v); + if r.is_zero() { + Ordering::Equal + } else if r < &0 { + Ordering::Less + } else { + Ordering::Greater + } + } +} + +impl PartialOrd for MclFr { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for MclFr { + fn eq(&self, other: &Self) -> bool { + self.v == other.v + } +} + +impl Hash for MclFr { + fn hash(&self, state: &mut H) { + let mut buf: Vec = vec![]; + let mut v = Fr::zero(); + v.set_hash_of(&mut buf); + buf.hash(state); + } +} + +impl Eq for MclFr {} + +impl fmt::Debug for MclFr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.v.get_str(10)) + } +} + +impl fmt::Display for MclFr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.v.get_str(10)) + } +} + +macro_rules! impl_neg { + ($target: ty) => { + impl Neg for $target { + type Output = MclFr; + + fn neg(self) -> Self::Output { + let mut v = Fr::zero(); + Fr::neg(&mut v, &self.v); + MclFr { v } + } + } + } +} +impl_neg!(MclFr); +impl_neg!(&MclFr); + +macro_rules! impl_add { + ($rhs: ty, $target: ty) => { + impl Add<$rhs> for $target { + type Output = MclFr; + + fn add(self, rhs: $rhs) -> Self::Output { + let mut v = Fr::zero(); + Fr::add(&mut v, &self.v, &rhs.v); + MclFr { v } + } + } + }; +} +impl_add!(MclFr, MclFr); +impl_add!(&MclFr, MclFr); +impl_add!(MclFr, &MclFr); +impl_add!(&MclFr, &MclFr); + +macro_rules! impl_sub { + ($rhs: ty, $target: ty) => { + impl Sub<$rhs> for $target { + type Output = MclFr; + + fn sub(self, rhs: $rhs) -> Self::Output { + let mut v = Fr::zero(); + Fr::sub(&mut v, &self.v, &rhs.v); + MclFr { v } + } + } + }; +} +impl_sub!(MclFr, MclFr); +impl_sub!(&MclFr, MclFr); +impl_sub!(MclFr, &MclFr); +impl_sub!(&MclFr, &MclFr); + +macro_rules! impl_mul { + ($rhs: ty, $target: ty) => { + impl Mul<$rhs> for $target { + type Output = MclFr; + + fn mul(self, rhs: $rhs) -> Self::Output { + let mut v = Fr::zero(); + Fr::mul(&mut v, &self.v, &rhs.v); + MclFr { v } + } + } + }; +} +impl_mul!(MclFr, MclFr); +impl_mul!(&MclFr, MclFr); +impl_mul!(MclFr, &MclFr); +impl_mul!(&MclFr, &MclFr); + +#[cfg(test)] +mod tests { + use super::*; + use crate::building_block::mcl::mcl_initializer::MclInitializer; + + #[test] + fn test_add() { + MclInitializer::init(); + + let n3 = MclFr::from(3); + let n9 = MclFr::from(9); + let exp = MclFr::from(12); + let act = n3 + n9; + assert_eq!(exp, act); + } + + #[test] + fn test_sub() { + MclInitializer::init(); + + let n9 = MclFr::from(9); + let n3 = MclFr::from(3); + let exp = MclFr::from(6); + let act = n9 - n3; + assert_eq!(exp, act); + } + + #[test] + fn test_mul() { + MclInitializer::init(); + + let n3 = MclFr::from(3); + let n9 = MclFr::from(9); + let exp = MclFr::from(27); + let act = n3 * n9; + assert_eq!(exp, act); + } + + + #[test] + fn test_inv() { + MclInitializer::init(); + + let n1 = MclFr::from(1); + let n9 = MclFr::from(9); + let inv9 = n9.inv(); + + assert_eq!(n9 * inv9, n1); + } + + #[test] + fn test_sq() { + MclInitializer::init(); + + let n3 = MclFr::from(3); + let n9 = MclFr::from(9); + + assert_eq!(n3.sq(), n9); + } + + #[test] + fn test_neg() { + MclInitializer::init(); + + let n9 = &MclFr::from(9); + assert_eq!(n9 + -n9, MclFr::zero()); + } + + #[test] + fn test_inc() { + MclInitializer::init(); + + let n1 = MclFr::from(1); + let n2 = MclFr::from(2); + let n3 = MclFr::from(3); + + let mut n = MclFr::zero(); + assert!(n.is_zero()); + n.inc(); + assert_eq!(n, n1); + n.inc(); + assert_eq!(n, n2); + n.inc(); + assert_eq!(n, n3); + } + + #[test] + fn test_ord() { + MclInitializer::init(); + + let n2 = &MclFr::from(2); + let n3 = &MclFr::from(3); + + assert!((n2 == n2) == true); + assert!((n2 != n2) == false); + assert!((n2 < n2) == false); + assert!((n2 > n2) == false); + assert!((n2 >= n2) == true); + assert!((n2 <= n2) == true); + + assert!((n2 == n3) == false); + assert!((n2 != n3) == true); + assert!((n2 < n3) == true); + assert!((n2 > n3) == false); + assert!((n2 >= n3) == false); + assert!((n2 <= n3) == true); + } + + #[test] + fn test_hashing() { + MclInitializer::init(); + use std::collections::HashMap; + + let n2 = &MclFr::from(2); + let n3 = &MclFr::from(3); + let n4 = &MclFr::from(3); + + let m = HashMap::::from([ + ("2".to_string(), n2.clone()), + ("3".to_string(), n3.clone()), + ("4".to_string(), n4.clone()), + ]); + + assert_eq!(m.get("2").unwrap(), n2); + assert_eq!(m.get("3").unwrap(), n3); + assert_eq!(m.get("4").unwrap(), n4); + } +} + + + + + + + + + + + + + + + + + + + + diff --git a/src/building_block/mcl/mcl_g1.rs b/src/building_block/mcl/mcl_g1.rs new file mode 100644 index 0000000..2fe297c --- /dev/null +++ b/src/building_block/mcl/mcl_g1.rs @@ -0,0 +1,186 @@ +use mcl_rust::*; +use std::ops::{Add, Mul, Neg, AddAssign}; +use num_traits::Zero; +use once_cell::sync::Lazy; +use crate::building_block::mcl::mcl_fr::MclFr; + +#[derive(Clone, Debug)] +pub struct MclG1 { + pub v: G1, +} + +static GENERATOR: Lazy = Lazy::new(|| { + let serialized_g = "1 3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507 1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569"; + let mut v = G1::zero(); + G1::set_str(&mut v, serialized_g, 10); + MclG1 { v } +}); + +impl MclG1 { + pub fn new() -> Self { + let v = G1::zero(); + MclG1::from(&v) + } + + pub fn g() -> MclG1 { + GENERATOR.clone() + } + + pub fn inv(&self) -> Self { + let mut v = G1::zero(); + G1::neg(&mut v, &self.v); + MclG1::from(&v) + } + + pub fn get_random_point() -> MclG1 { + let mut v = Fr::zero(); + v.set_by_csprng(); + MclG1::g() * MclFr::from(&v) + } +} + +impl Zero for MclG1 { + fn zero() -> MclG1 { + let v = G1::zero(); + MclG1::from(&v) + } + + fn is_zero(&self) -> bool { + self.v.is_zero() + } +} + +impl From<&G1> for MclG1 { + fn from(v: &G1) -> Self { + MclG1 { v: v.clone() } + } +} + +macro_rules! impl_add { + ($rhs: ty, $target: ty) => { + impl Add<$rhs> for $target { + type Output = MclG1; + + fn add(self, rhs: $rhs) -> Self::Output { + let mut v = G1::zero(); + G1::add(&mut v, &self.v, &rhs.v); + MclG1::from(&v) + } + } + }; +} +impl_add!(MclG1, MclG1); +impl_add!(&MclG1, MclG1); +impl_add!(MclG1, &MclG1); +impl_add!(&MclG1, &MclG1); + +macro_rules! impl_mul { + ($rhs: ty, $target: ty) => { + impl Mul<$rhs> for $target { + type Output = MclG1; + + fn mul(self, rhs: $rhs) -> Self::Output { + let mut v = G1::zero(); + G1::mul(&mut v, &self.v, &rhs.v); + MclG1::from(&v) + } + } + }; +} +impl_mul!(MclFr, MclG1); +impl_mul!(&MclFr, MclG1); +impl_mul!(MclFr, &MclG1); +impl_mul!(&MclFr, &MclG1); + +impl AddAssign for MclG1 { + fn add_assign(&mut self, rhs: MclG1) { + *self = &*self + rhs + } +} + +impl PartialEq for MclG1 { + fn eq(&self, rhs: &Self) -> bool { + self.v == rhs.v + } +} + +impl Eq for MclG1 {} + +macro_rules! impl_neg { + ($target: ty) => { + impl Neg for $target { + type Output = MclG1; + + fn neg(self) -> Self::Output { + let mut v = G1::zero(); + G1::neg(&mut v, &self.v); + MclG1::from(&v) + } + } + } +} +impl_neg!(MclG1); +impl_neg!(&MclG1); + +#[cfg(test)] +mod tests { + use super::*; + use crate::building_block::mcl::mcl_initializer::MclInitializer; + + #[test] + fn equality() { + MclInitializer::init(); + + let g = MclG1::g(); + let g2 = &g + &g; + + assert_eq!(&g, &g); + assert_eq!(&g2, &g2); + assert_ne!(&g, &g2); + } + + #[test] + fn add() { + MclInitializer::init(); + + let g = &MclG1::g(); + let g2 = &(g + g); + let g4 = &(g2 + g2); + + { + let act = g + g; + let exp = g2; + assert_eq!(&act, exp); + } + { + let act = g2 + g2; + let exp = g4; + assert_eq!(&act, exp); + } + } + + #[test] + fn scalar_mul() { + MclInitializer::init(); + + let g = &MclG1::g(); + let n4 = MclFr::from(4); + let act = g * n4; + let exp = g + g + g + g; + + assert_eq!(act, exp); + } + + #[test] + fn neg() { + MclInitializer::init(); + + let g = &MclG1::g(); + let n4 = &MclFr::from(4); + let g_n4 = g * n4; + let g_n4_neg = (g * n4).neg(); + let act = g_n4 + g_n4_neg; + let exp = MclG1::zero(); + assert_eq!(act, exp); + } +} diff --git a/src/building_block/mcl/mcl_g2.rs b/src/building_block/mcl/mcl_g2.rs new file mode 100644 index 0000000..d872d2b --- /dev/null +++ b/src/building_block/mcl/mcl_g2.rs @@ -0,0 +1,199 @@ +use mcl_rust::*; +use std::ops::{Add, Mul, Neg, AddAssign}; +use num_traits::Zero; +use once_cell::sync::Lazy; +use crate::building_block::mcl::mcl_fr::MclFr; + +#[derive(Clone, Debug)] +pub struct MclG2 { + pub v: G2, +} + +static GENERATOR: Lazy = Lazy::new(|| { + let serialized_g = "1 352701069587466618187139116011060144890029952792775240219908644239793785735715026873347600343865175952761926303160 3059144344244213709971259814753781636986470325476647558659373206291635324768958432433509563104347017837885763365758 1985150602287291935568054521177171638300868978215655730859378665066344726373823718423869104263333984641494340347905 927553665492332455747201965776037880757740193453592970025027978793976877002675564980949289727957565575433344219582"; + let mut v = G2::zero(); + G2::set_str(&mut v, serialized_g, 10); + MclG2 { v } +}); + +impl MclG2 { + pub fn new() -> Self { + let v = G2::zero(); + MclG2::from(&v) + } + + pub fn g() -> MclG2 { + GENERATOR.clone() + } + + pub fn inv(&self) -> Self { + let mut v = G2::zero(); + G2::neg(&mut v, &self.v); + MclG2::from(&v) + } + + pub fn get_random_point() -> MclG2 { + let mut v = Fr::zero(); + v.set_by_csprng(); + MclG2::g() * MclFr::from(&v) + } + + pub fn hash_and_map(buf: &Vec) -> MclG2 { + let mut v = G2::zero(); + G2::set_hash_of(&mut v, buf); + MclG2::from(&v) + } +} + +impl Zero for MclG2 { + fn zero() -> MclG2 { + let v = G2::zero(); + MclG2::from(&v) + } + + fn is_zero(&self) -> bool { + self.v.is_zero() + } +} + +impl From<&G2> for MclG2 { + fn from(v: &G2) -> Self { + MclG2 { v: v.clone() } + } +} + +macro_rules! impl_add { + ($rhs: ty, $target: ty) => { + impl Add<$rhs> for $target { + type Output = MclG2; + + fn add(self, rhs: $rhs) -> Self::Output { + let mut v = G2::zero(); + G2::add(&mut v, &self.v, &rhs.v); + MclG2::from(&v) + } + } + }; +} +impl_add!(MclG2, MclG2); +impl_add!(&MclG2, MclG2); +impl_add!(MclG2, &MclG2); +impl_add!(&MclG2, &MclG2); + +macro_rules! impl_mul { + ($rhs: ty, $target: ty) => { + impl Mul<$rhs> for $target { + type Output = MclG2; + + fn mul(self, rhs: $rhs) -> Self::Output { + let mut v = G2::zero(); + G2::mul(&mut v, &self.v, &rhs.v); + MclG2::from(&v) + } + } + }; +} +impl_mul!(MclFr, MclG2); +impl_mul!(&MclFr, MclG2); +impl_mul!(MclFr, &MclG2); +impl_mul!(&MclFr, &MclG2); + +impl AddAssign for MclG2 { + fn add_assign(&mut self, rhs: MclG2) { + *self = &*self + rhs + } +} + +impl PartialEq for MclG2 { + fn eq(&self, rhs: &Self) -> bool { + self.v == rhs.v + } +} + +impl Eq for MclG2 {} + +macro_rules! impl_neg { + ($target: ty) => { + impl Neg for $target { + type Output = MclG2; + + fn neg(self) -> Self::Output { + let mut v = G2::zero(); + G2::neg(&mut v, &self.v); + MclG2::from(&v) + } + } + } +} +impl_neg!(MclG2); +impl_neg!(&MclG2); + +#[cfg(test)] +mod tests { + use super::*; + use crate::building_block::mcl::mcl_initializer::MclInitializer; + + #[test] + fn equality() { + MclInitializer::init(); + + let g = MclG2::g(); + println!("g {:?}", &g); + let g2 = &g + &g; + + assert_eq!(&g, &g); + assert_eq!(&g2, &g2); + assert_ne!(&g, &g2); + } + + #[test] + fn add() { + MclInitializer::init(); + + let g = &MclG2::g(); + println!("g is {:?}", &g); + let g2 = &(g + g); + let g4 = &(g2 + g2); + + { + let act = g + g; + let exp = g2; + assert_eq!(&act, exp); + } + { + let act = g + g; + let exp = g * MclFr::from(2); + assert_eq!(&act, &exp); + } + { + let act = g2 + g2; + let exp = g4; + assert_eq!(&act, exp); + } + } + + #[test] + fn scalar_mul() { + MclInitializer::init(); + + let g = &MclG2::g(); + let n4 = MclFr::from(4); + let act = g * n4; + let exp = g + g + g + g; + + assert_eq!(act, exp); + } + + #[test] + fn neg() { + MclInitializer::init(); + + let g = &MclG2::g(); + let n4 = &MclFr::from(4); + let g_n4 = g * n4; + let g_n4_neg = (g * n4).neg(); + let act = g_n4 + g_n4_neg; + let exp = MclG2::zero(); + assert_eq!(act, exp); + } +} diff --git a/src/building_block/mcl/mcl_gt.rs b/src/building_block/mcl/mcl_gt.rs new file mode 100644 index 0000000..eac0736 --- /dev/null +++ b/src/building_block/mcl/mcl_gt.rs @@ -0,0 +1,200 @@ +use mcl_rust::*; +use std::{ + convert::From, + fmt, + ops::{Add, + Sub, + Mul, + Neg, + }, +}; +use num_traits::Zero; + +#[derive(Debug, Clone)] +pub struct MclGT { + pub v: GT, +} + +impl MclGT { + pub fn new() -> Self { + let v = GT::zero(); + MclGT::from(&v) + } + + pub fn inv(&self) -> Self { + let mut v = GT::zero(); + GT::inv(&mut v, &self.v); + MclGT::from(&v) + } + + pub fn sq(&self) -> Self { + let mut v = GT::zero(); + GT::sqr(&mut v, &self.v); + MclGT::from(&v) + } +} + +impl Zero for MclGT { + fn is_zero(&self) -> bool { + self.v.is_zero() + } + + fn zero() -> Self { + MclGT::from(>::zero()) + } +} + +impl From for MclGT { + fn from(value: i32) -> Self { + let v = GT::from_int(value); + MclGT { v } + } +} + +impl From<>> for MclGT { + fn from(v: >) -> Self { + MclGT { v: v.clone() } + } +} + +impl PartialEq for MclGT { + fn eq(&self, other: &Self) -> bool { + self.v == other.v + } +} + +impl Eq for MclGT {} + +impl fmt::Display for MclGT { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.v.get_str(16)) + } +} + +macro_rules! impl_neg { + ($target: ty) => { + impl Neg for $target { + type Output = MclGT; + + fn neg(self) -> Self::Output { + let mut v = GT::zero(); + GT::neg(&mut v, &self.v); + MclGT::from(&v) + } + } + } +} +impl_neg!(MclGT); +impl_neg!(&MclGT); + +macro_rules! impl_add { + ($rhs: ty, $target: ty) => { + impl Add<$rhs> for $target { + type Output = MclGT; + + fn add(self, rhs: $rhs) -> Self::Output { + let mut v = GT::zero(); + GT::add(&mut v, &self.v, &rhs.v); + MclGT::from(&v) + } + } + }; +} +impl_add!(MclGT, MclGT); +impl_add!(&MclGT, MclGT); +impl_add!(MclGT, &MclGT); +impl_add!(&MclGT, &MclGT); + +macro_rules! impl_sub { + ($rhs: ty, $target: ty) => { + impl Sub<$rhs> for $target { + type Output = MclGT; + + fn sub(self, rhs: $rhs) -> Self::Output { + let mut v = GT::zero(); + GT::sub(&mut v, &self.v, &rhs.v); + MclGT::from(&v) + } + } + }; +} +impl_sub!(MclGT, MclGT); +impl_sub!(&MclGT, MclGT); +impl_sub!(MclGT, &MclGT); +impl_sub!(&MclGT, &MclGT); + +macro_rules! impl_mul { + ($rhs: ty, $target: ty) => { + impl Mul<$rhs> for $target { + type Output = MclGT; + + fn mul(self, rhs: $rhs) -> Self::Output { + let mut v = GT::zero(); + GT::mul(&mut v, &self.v, &rhs.v); + MclGT { v } + } + } + }; +} +impl_mul!(MclGT, MclGT); +impl_mul!(&MclGT, MclGT); +impl_mul!(MclGT, &MclGT); +impl_mul!(&MclGT, &MclGT); + +#[cfg(test)] +mod tests { + use super::*; + use crate::building_block::mcl::mcl_initializer::MclInitializer; + + #[test] + fn test_add() { + MclInitializer::init(); + + let n3 = MclGT::from(3i32); + let n9 = MclGT::from(9i32); + let exp = MclGT::from(12i32); + let act = n3 + n9; + assert_eq!(exp, act); + } + + #[test] + fn test_sub() { + MclInitializer::init(); + + let n9 = MclGT::from(9i32); + let n3 = MclGT::from(3i32); + let exp = MclGT::from(6i32); + let act = n9 - n3; + assert_eq!(exp, act); + } + + #[test] + fn test_mul() { + MclInitializer::init(); + + let n3 = MclGT::from(3i32); + let n9 = MclGT::from(9i32); + let exp = MclGT::from(27i32); + let act = n3 * n9; + assert_eq!(exp, act); + } + + // #[test] + // fn test_inv() { + // MclInitializer::init(); + // + // let n1 = MclGT::from(1i32); + // let n9 = MclGT::from(9i32); + // let inv9 = n9.inv(); + // + // assert_eq!(n9 * inv9, n1); + // } + + #[test] + fn test_neg() { + MclInitializer::init(); + + let n9 = &MclGT::from(9i32); + assert_eq!(n9 + -n9, MclGT::zero()); + } +} diff --git a/src/building_block/mcl/mcl_initializer.rs b/src/building_block/mcl/mcl_initializer.rs new file mode 100644 index 0000000..482a48f --- /dev/null +++ b/src/building_block/mcl/mcl_initializer.rs @@ -0,0 +1,26 @@ +use mcl_rust::*; +use std::sync::Once; + +static INIT: Once = Once::new(); + +pub struct MclInitializer; + +impl MclInitializer { + pub fn init() { + INIT.call_once(|| { + if !init(CurveType::BLS12_381) { + panic!("Failed to initialize mcl"); + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn init() { + MclInitializer::init(); + } +} diff --git a/src/building_block/mcl/mcl_sparse_matrix.rs b/src/building_block/mcl/mcl_sparse_matrix.rs new file mode 100644 index 0000000..f7d0428 --- /dev/null +++ b/src/building_block/mcl/mcl_sparse_matrix.rs @@ -0,0 +1,824 @@ +use num_traits::Zero; + +use crate::building_block::mcl::{ + mcl_fr::MclFr, + mcl_sparse_vec::MclSparseVec, + polynomial::Polynomial, +}; +use std::{ + collections::HashMap, + convert::From, + fmt, + ops::Mul, +}; + +pub struct MclSparseMatrix { + pub width: MclFr, + pub height: MclFr, + rows: HashMap, + zero: MclFr, +} + +impl MclSparseMatrix { + pub fn new(width: &MclFr, height: &MclFr) -> Self { + let zero = MclFr::zero(); + let rows = HashMap::new(); + Self { + width: width.clone(), + height: height.clone(), + rows, + zero, + } + } + + pub fn pretty_print(&self) -> String { + let mut s = String::new(); + let empty_row = &MclSparseVec::new(&self.width); + + let mut y = MclFr::zero(); + while y < self.height { + match self.rows.get(&y) { + Some(row) => { + s = format!("{}{}\n", s, row.pretty_print()); + }, + None => { + s = format!("{}{}\n", s, empty_row.pretty_print()); + } + } + y.inc(); + } + s + } + + pub fn multiply_column(&self, col: &MclSparseVec) -> Self { + if col.size != self.height { + panic!("column size is expected to be {:?}, but got {:?}", + self.height, col.size) + } + let mut m = MclSparseMatrix::new(&self.width, &self.height); + + let mut y = MclFr::from(0); + while y < col.size { + let mut x = MclFr::from(0); + let multiplier = col.get(&y); + while x < self.width { + let v = self.get(&x, &y) * multiplier; + m.set(&x, &y, &v); + x.inc(); + } + y.inc(); + } + m + } + + pub fn flatten_rows(&self) -> MclSparseVec { + let mut vec = MclSparseVec::new(&self.width); + + let mut y = MclFr::from(0); + while y < self.height { + println!("y={:?}", &y); + let mut x = MclFr::from(0); + while x < self.width { + println!("x={:?}", &x); + let v = vec.get(&x) + self.get(&x, &y); + vec.set(&x, &v); + x.inc(); + } + y.inc(); + } + vec + } + + pub fn set(&mut self, x: &MclFr, y: &MclFr, v: &MclFr) -> () { + if x >= &self.width || y >= &self.height { + panic!("For {:?} x {:?} matrix, ({:?}, {:?}) is out of range", + self.width, self.height, x, y); + } + if v.is_zero() { // don't set if zero + return; + } + + if !self.rows.contains_key(&y) { + let vec = MclSparseVec::new(&self.width); + self.rows.insert(y.clone(), vec); + } + self.rows.get_mut(&y).unwrap().set(&x, &v); + } + + pub fn get(&self, x: &MclFr, y: &MclFr) -> &MclFr { + if x >= &self.width || y >= &self.height { + panic!("For {:?} x {:?} matrix, ({:?}, {:?}) is out of range", + self.width, self.height, x, y); + } + if !self.rows.contains_key(&y) { + &self.zero + } else { + self.rows.get(&y).unwrap().get(&x) + } + } + + pub fn get_row(&self, y: &MclFr) -> MclSparseVec { + assert!(y < &self.height); + let mut row = MclSparseVec::new(&self.width); + + if !self.rows.contains_key(y) { + return row; + } + let src_row = self.rows.get(y).unwrap(); + for x in src_row.indices() { + let v = src_row.get(&x); + if !v.is_zero() { + row.set(&x, v); + } + } + row + } + + pub fn get_column(&self, x: &MclFr) -> MclSparseVec { + assert!(x < &self.width); + let mut col = MclSparseVec::new(&self.height); + + for y in self.rows.keys() { + let src_row = self.rows.get(&y).unwrap(); + let v = src_row.get(x); + if !v.is_zero() { + col.set(y, v); + } + } + col + } + + pub fn transpose(&self) -> MclSparseMatrix { + let mut m = MclSparseMatrix::new(&self.height, &self.width); + for y in self.rows.keys() { + let src_row = self.rows.get(&y).unwrap(); + + for x in src_row.indices() { + let v = src_row.get(&x); + if !v.is_zero() { + m.set(y, &x, v); + } + } + } + m + } + + // remove empty rows + pub fn normalize(&self) -> MclSparseMatrix { + let mut m = MclSparseMatrix::new(&self.width, &self.height); + for row_key in self.rows.keys() { + let row = self.rows.get(row_key).unwrap(); + if !row.is_empty() { + m.rows.insert(row_key.clone(), row.clone()); + } + } + m + } + + pub fn row_transform(&self, transform: Box MclSparseVec>) -> MclSparseMatrix { + let mut m = MclSparseMatrix::new(&self.width, &self.height); + + let mut y = MclFr::zero(); + while y < self.height { + let in_row = self.get_row(&y); + let out_row = transform(&in_row); + + let mut x = MclFr::zero(); + while x < self.width { + let v = out_row.get(&x); + m.set(&x, &y, v); + x.inc(); + } + y.inc(); + } + m + } +} + +impl PartialEq for MclSparseMatrix { + fn eq(&self, other: &MclSparseMatrix) -> bool { + if self.width != other.width || self.height != other.height { + return false; + } + for key in self.rows.keys() { + if !other.rows.contains_key(key) { + return false; + } + let self_row = &self.rows[key]; + let other_row = &other.rows[key]; + + if self_row != other_row { + return false; + } + } + for key in other.rows.keys() { + if !self.rows.contains_key(key) { + return false; + } + let self_row = &self.rows[key]; + let other_row = &other.rows[key]; + + if self_row != other_row { + return false; + } + } + true + } +} + +impl Into> for MclSparseMatrix { + fn into(self) -> Vec { + let mut vec = vec![]; + let mut i = MclFr::zero(); + while &i < &self.height { + let p = Polynomial::from(&self.get_row(&i)); + vec.push(p); + i.inc(); + } + vec + } +} + +// coverts rows of vectors to a matrix +impl From<&Vec> for MclSparseMatrix { + fn from(rows: &Vec) -> Self { + assert!(rows.len() != 0, "Cannot build matrix from empty vector"); + let width = &rows[0].size; + let height = rows.len(); + + for i in 1..height { + if width != &rows[i].size { + panic!("different row sizes found; size is {:?} at 0, but {:?} at {}", + width, &rows[i].size, i) + } + } + let mut m = MclSparseMatrix::new(width, &MclFr::from(height)); + + for (y, row) in rows.iter().enumerate() { + for x in row.indices() { + let v = row.get(&x); + if !v.is_zero() { + m.set(&x, &MclFr::from(y), v); + } + } + } + m.normalize() + } +} + +impl Mul<&MclSparseMatrix> for &MclSparseMatrix { + type Output = MclSparseMatrix; + + fn mul(self, rhs: &MclSparseMatrix) -> Self::Output { + if self.width != rhs.height { + panic!("Can only multiply matrix with height {:?}, but the rhs height is {:?}", + self.width, rhs.height); + } + let mut res = MclSparseMatrix::new(&rhs.width, &self.height); + + let mut y = MclFr::zero(); + while y < self.height { + let mut x = MclFr::zero(); + while x < rhs.width { + let lhs = &self.get_row(&y); + let rhs = &rhs.get_column(&x); + let v = (lhs * rhs).sum(); + if !v.is_zero() { + res.set(&x, &y, &v); + } + x.inc(); + } + y.inc(); + } + res.normalize() + } +} + +impl fmt::Debug for MclSparseMatrix { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", &self.pretty_print()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::building_block::mcl::mcl_initializer::MclInitializer; + + #[test] + fn test_size() { + MclInitializer::init(); + let m = MclSparseMatrix::new(&MclFr::from(2), &MclFr::from(3)); + assert_eq!(m.width, MclFr::from(2)); + assert_eq!(m.height, MclFr::from(3)); + } + + #[test] + #[should_panic] + fn test_get_x_out_of_range() { + MclInitializer::init(); + std::panic::set_hook(Box::new(|_| {})); // suppress stack trace + + let m = MclSparseMatrix::new(&MclFr::from(2), &MclFr::from(3)); + m.get(&MclFr::from(2), &MclFr::from(1)); + } + + #[test] + #[should_panic] + fn test_get_y_out_of_range() { + MclInitializer::init(); + std::panic::set_hook(Box::new(|_| {})); // suppress stack trace + + let m = MclSparseMatrix::new(&MclFr::from(2), &MclFr::from(3)); + m.get(&MclFr::from(1), &MclFr::from(3)); + } + + #[test] + #[should_panic] + fn test_set_x_out_of_range() { + MclInitializer::init(); + std::panic::set_hook(Box::new(|_| {})); // suppress stack trace + + let mut m = MclSparseMatrix::new(&MclFr::from(2), &MclFr::from(3)); + m.set(&MclFr::from(2), &MclFr::from(1), &MclFr::from(12)); + } + + #[test] + #[should_panic] + fn test_set_y_out_of_range() { + MclInitializer::init(); + std::panic::set_hook(Box::new(|_| {})); // suppress stack trace + + let mut m = MclSparseMatrix::new(&MclFr::from(2), &MclFr::from(3)); + m.set(&MclFr::from(1), &MclFr::from(3), &MclFr::from(12)); + } + + #[test] + fn test_get_empty() { + MclInitializer::init(); + let zero = &MclFr::from(0); + + let m = MclSparseMatrix::new(&MclFr::from(2), &MclFr::from(3)); + for x in 0..2 { + for y in 0..3 { + assert_eq!(m.get(&MclFr::from(x), &MclFr::from(y)), zero); + } + } + } + + #[test] + fn test_mul() { + MclInitializer::init(); + let zero = &MclFr::from(0); + + let m = MclSparseMatrix::new(&MclFr::from(2), &MclFr::from(3)); + for x in 0..2 { + for y in 0..3 { + assert_eq!(m.get(&MclFr::from(x), &MclFr::from(y)), zero); + } + } + } + + #[test] + fn test_get_existing_and_non_existing_cells() { + MclInitializer::init(); + let zero = &MclFr::from(0); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let three = &MclFr::from(3); + let eight = &MclFr::from(8); + let nine = &MclFr::from(9); + + let mut m = MclSparseMatrix::new(two, three); + m.set(zero, two, nine); + m.set(one, one, eight); + + assert_eq!(m.get(zero, two), nine); + assert_eq!(m.get(one, one), eight); + assert_eq!(m.get(one, two), zero); + } + + #[test] + #[should_panic] + fn test_from_empty_vec() { + MclInitializer::init(); + std::panic::set_hook(Box::new(|_| {})); // suppress stack trace + let _ = MclSparseMatrix::from(&vec![]); + } + + // |1 0| + // |0 1| + fn gen_test_2x2_identity_matrix() -> MclSparseMatrix { + MclInitializer::init(); + let zero = &MclFr::from(0); + let one = &MclFr::from(1); + + let mut v1 = MclSparseVec::new(&MclFr::from(2)); + let mut v2 = MclSparseVec::new(&MclFr::from(2)); + v1.set(zero, one); + v2.set(one, one); + let vecs = vec![v1, v2]; + MclSparseMatrix::from(&vecs) + } + + // |1 2| + // |3 4| + fn gen_test_2x2_matrix() -> MclSparseMatrix { + MclInitializer::init(); + let zero = &MclFr::from(0); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let three = &MclFr::from(3); + let four = &MclFr::from(4); + + let mut v1 = MclSparseVec::new(&MclFr::from(2)); + let mut v2 = MclSparseVec::new(&MclFr::from(2)); + v1.set(zero, one); + v1.set(one, two); + v2.set(zero, three); + v2.set(one, four); + let vecs = vec![v1, v2]; + MclSparseMatrix::from(&vecs) + } + + // |1 0| + // |0 2| + // |3 0| + fn gen_test_2x3_matrix() -> MclSparseMatrix { + MclInitializer::init(); + let zero = &MclFr::from(0); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let three = &MclFr::from(3); + + let mut v1 = MclSparseVec::new(&MclFr::from(2)); + let mut v2 = MclSparseVec::new(&MclFr::from(2)); + let mut v3 = MclSparseVec::new(&MclFr::from(2)); + v1.set(zero, one); + v2.set(one, two); + v3.set(zero, three); + let vecs = vec![v1, v2, v3]; + MclSparseMatrix::from(&vecs) + } + + // |1 2 3| + // |3 2 1| + fn gen_test_3x2_matrix() -> MclSparseMatrix { + MclInitializer::init(); + let zero = &MclFr::from(0); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let three = &MclFr::from(3); + + let mut v1 = MclSparseVec::new(&MclFr::from(3)); + let mut v2 = MclSparseVec::new(&MclFr::from(3)); + v1.set(zero, one); + v1.set(one, two); + v1.set(two, three); + v2.set(zero, three); + v2.set(one, two); + v2.set(two, one); + let vecs = vec![v1, v2]; + MclSparseMatrix::from(&vecs) + } + + // |1 2| + // |0 0| + fn gen_test_2x2_redundant_matrix(use_empty_row: bool) -> MclSparseMatrix { + MclInitializer::init(); + let zero = &MclFr::from(0); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + + let mut v1 = MclSparseVec::new(&MclFr::from(2)); + v1.set(zero, one); + v1.set(one, two); + + if use_empty_row { + let mut rows = HashMap::::new(); + rows.insert(zero.clone(), v1.clone()); + rows.insert(one.clone(), MclSparseVec::new(&MclFr::from(2))); + + MclSparseMatrix { + width: two.clone(), + height: two.clone(), + rows, + zero: zero.clone(), + } + } else { + let mut rows = HashMap::::new(); + rows.insert(zero.clone(), v1.clone()); + + MclSparseMatrix { + width: two.clone(), + height: two.clone(), + rows, + zero: zero.clone(), + } + } + } + + #[test] + fn test_eq() { + MclInitializer::init(); + { + let m1 = gen_test_2x2_identity_matrix(); + let m2 = gen_test_2x2_matrix(); + let m3 = gen_test_3x2_matrix(); + + assert!(&m1 == &m1); + assert!(&m2 == &m2); + assert!(&m3 == &m3); + } + { + let m1 = gen_test_3x2_matrix(); + let m2 = gen_test_3x2_matrix(); + + assert!(&m1 == &m2); + assert!(&m2 == &m1); + } + } + + #[test] + fn test_eq_with_redundant_matrix() { + MclInitializer::init(); + let m1 = gen_test_2x2_redundant_matrix(true); + let m2 = gen_test_2x2_redundant_matrix(false); + + assert!(&m1 != &m2); + assert!(&m2 != &m1); + } + + #[test] + fn test_non_eq() { + MclInitializer::init(); + let m1 = gen_test_2x2_identity_matrix(); + let m2 = gen_test_2x2_matrix(); + let m3 = gen_test_3x2_matrix(); + + assert!(&m1 != &m2); + assert!(&m2 != &m1); + + assert!(&m1 != &m3); + assert!(&m3 != &m1); + + assert!(&m2 != &m1); + assert!(&m1 != &m2); + + assert!(&m2 != &m3); + assert!(&m3 != &m2); + + assert!(&m3 != &m1); + assert!(&m1 != &m3); + + assert!(&m3 != &m2); + assert!(&m2 != &m3); + } + + #[test] + fn test_normalize() { + MclInitializer::init(); + let m1 = gen_test_2x2_redundant_matrix(true); + let m2 = gen_test_2x2_redundant_matrix(false); + let m1 = m1.normalize(); + + assert!(&m1 == &m2); + assert!(&m2 == &m1); + } + + #[test] + fn test_from_sparse_vecs() { + MclInitializer::init(); + let m = gen_test_2x3_matrix(); + let zero = &MclFr::from(0); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let three = &MclFr::from(3); + + assert_eq!(&m.width, two); + assert_eq!(&m.height, three); + + assert_eq!(m.get(zero, zero), one); + assert_eq!(m.get(one, one), two); + assert_eq!(m.get(zero, two), three); + } + + #[test] + #[should_panic] + fn test_get_row_x_out_of_range() { + MclInitializer::init(); + std::panic::set_hook(Box::new(|_| {})); // suppress stack trace + let m = gen_test_2x3_matrix(); + let _ = m.get_row(&m.height); + } + + #[test] + fn test_get_row_x_within_range() { + MclInitializer::init(); + let m = gen_test_2x3_matrix(); + let zero = &MclFr::from(0); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let three = &MclFr::from(3); + + let r0 = m.get_row(zero); + assert_eq!(r0.get(zero), one); + assert_eq!(r0.get(one), zero); + + let r1 = m.get_row(one); + assert_eq!(r1.get(zero), zero); + assert_eq!(r1.get(one), two); + + let r2 = m.get_row(two); + assert_eq!(r2.get(zero), three); + assert_eq!(r2.get(one), zero); + } + + #[test] + #[should_panic] + fn test_get_column_out_of_range() { + MclInitializer::init(); + std::panic::set_hook(Box::new(|_| {})); // suppress stack trace + let m = gen_test_2x3_matrix(); + let _ = m.get_column(&m.width); + } + + #[test] + fn test_get_column_within_range() { + MclInitializer::init(); + let m = gen_test_2x3_matrix(); + let zero = &MclFr::from(0); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let three = &MclFr::from(3); + + let c0 = m.get_column(zero); + assert_eq!(c0.get(zero), one); + assert_eq!(c0.get(one), zero); + assert_eq!(c0.get(two), three); + + let c1 = m.get_column(one); + assert_eq!(c1.get(zero), zero); + assert_eq!(c1.get(one), two); + assert_eq!(c1.get(two), zero); + } + + #[test] + #[should_panic] + fn test_mul_incompatible_matrices() { + MclInitializer::init(); + std::panic::set_hook(Box::new(|_| {})); // suppress stack trace + let m = gen_test_2x3_matrix(); + let _ = &m * &m; + } + + #[test] + fn test_mul_different_sizes() { + MclInitializer::init(); + let m1 = gen_test_3x2_matrix(); + let m2 = gen_test_2x3_matrix(); + let m3 = &m1 * &m2; + + let zero = &MclFr::from(0); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let six = &MclFr::from(6); + let four = &MclFr::from(4); + let ten = &MclFr::from(10); + + assert_eq!(&m3.width, two); + assert_eq!(&m3.height, two); + assert_eq!(m3.get(zero, zero), ten); + assert_eq!(m3.get(one, zero), four); + assert_eq!(m3.get(zero, one), six); + assert_eq!(m3.get(one, one), four); + } + + #[test] + fn test_mul_identity() { + MclInitializer::init(); + let m1 = gen_test_2x2_matrix(); + let m2 = gen_test_2x2_identity_matrix(); + let m3 = &m1 * &m2; + + let two = &MclFr::from(2); + assert_eq!(&m3.width, two); + assert_eq!(&m3.height, two); + + assert!(&m1 == &m3); + } + + #[test] + fn test_transpose() { + MclInitializer::init(); + let m = &gen_test_3x2_matrix(); + let mt = &m.transpose(); + + assert_eq!(m.height, mt.width); + assert_eq!(m.width, mt.height); + + let mut x = MclFr::zero(); + while x < m.width { + let mut y = MclFr::zero(); + while y < m.height { + let m_v = m.get(&x, &y); + let mt_v = mt.get(&y, &x); + assert_eq!(m_v, mt_v); + y.inc(); + } + x.inc(); + } + } + + #[test] + fn test_row_transform() { + MclInitializer::init(); + // |1 2 3| + // |3 2 1| + let m = gen_test_3x2_matrix(); + + let transform = move |in_vec: &MclSparseVec| { + let one = MclFr::from(1); + let mut out_vec = MclSparseVec::new(&in_vec.size); + let mut i = MclFr::from(0); + while i < in_vec.size { + let v = in_vec.get(&i) + &one; + out_vec.set(&i, &v); + i.inc(); + } + out_vec + }; + let m = m.row_transform(Box::new(transform)); + + { + let zero = &MclFr::from(0); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let three = &MclFr::from(3); + let four = &MclFr::from(4); + + // |2 3 4| + // |4 3 2| + assert_eq!(m.get(zero, zero), two); + assert_eq!(m.get(one, zero), three); + assert_eq!(m.get(two, zero), four); + assert_eq!(m.get(zero, one), four); + assert_eq!(m.get(one, one), three); + assert_eq!(m.get(two, one), two); + } + } + + #[test] + fn test_multiply_column() { + MclInitializer::init(); + // |1 2 3| + // |3 2 1| + let m = gen_test_3x2_matrix(); + + let zero = &MclFr::from(0); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let three = &MclFr::from(3); + let four = &MclFr::from(4); + let six = &MclFr::from(6); + let nine = &MclFr::from(9); + + // |2| + // |3| + let mut col = MclSparseVec::new(&m.height); + col.set(zero, two); + col.set(one, three); + + let m = m.multiply_column(&col); + + // |2 4 6| + // |9 6 3| + assert_eq!(m.get(zero, zero), two); + assert_eq!(m.get(one, zero), four); + assert_eq!(m.get(two, zero), six); + + assert_eq!(m.get(zero, one), nine); + assert_eq!(m.get(one, one), six); + assert_eq!(m.get(two, one), three); + } + + #[test] + fn test_flatten_rows() { + MclInitializer::init(); + // |1 2 3| + // |3 2 1| + let m = gen_test_3x2_matrix(); + + let zero = &MclFr::from(0); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let four = &MclFr::from(4); + + let row = m.flatten_rows(); + + // |4 2 4| + assert_eq!(row.get(zero), four); + assert_eq!(row.get(one), four); + assert_eq!(row.get(two), four); + } +} diff --git a/src/building_block/mcl/mcl_sparse_vec.rs b/src/building_block/mcl/mcl_sparse_vec.rs new file mode 100644 index 0000000..571fd5e --- /dev/null +++ b/src/building_block/mcl/mcl_sparse_vec.rs @@ -0,0 +1,559 @@ +use std::{ + collections::HashMap, + convert::From, + ops::Mul, +}; +use crate::building_block::mcl::mcl_fr::MclFr; +use num_traits::Zero; +use core::ops::{Index, IndexMut}; + +#[derive(Clone)] +pub struct MclSparseVec { + pub size: MclFr, + zero: MclFr, + elems: HashMap, // HashMap +} + +impl std::fmt::Debug for MclSparseVec { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + let s = self.pretty_print(); + write!(fmt, "{}", s) + } +} + +pub struct MclSparseVecIterator<'a> { + sv: &'a MclSparseVec, + i: MclFr, +} + +impl<'a> Iterator for MclSparseVecIterator<'a> { + type Item = MclFr; + + fn next(&mut self) -> Option { + if &self.sv.size == &self.i { + None + } else { + let elem = self.sv[&self.i].clone(); + self.i.inc(); + Some(elem) + } + } +} + +impl MclSparseVec { + pub fn new(size: &MclFr) -> Self { + if size.is_zero() { + panic!("Size must be greater than 0"); + } + MclSparseVec { + zero: MclFr::zero(), + size: size.clone(), + elems: HashMap::::new(), + } + } + + pub fn iter(&self) -> MclSparseVecIterator { + MclSparseVecIterator { sv: self, i: MclFr::zero() } + } + + pub fn set(&mut self, index: &MclFr, n: &MclFr) { + if index >= &self.size { + panic!("Index {:?} is out of range. The size of vector is {:?}", index, self.size); + } + if !n.is_zero() { + self.elems.insert(index.clone(), n.clone()); + } + } + + pub fn get(&self, index: &MclFr) -> &MclFr { + if index >= &self.size { + panic!("Index {:?} is out of range. The size of vector is {:?}", index, self.size); + } + if self.elems.contains_key(index) { + self.elems.get(index).unwrap() + } else { + &self.zero + } + } + + pub fn indices(&self) -> Vec { + let mut vec = vec![]; + for x in self.elems.keys() { + vec.push(x.clone()); + } + vec + } + + // TODO clean up + pub fn sum(&self) -> MclFr { + let mut values = vec![]; + for value in self.elems.values() { + values.push(value); + } + let mut sum = values[0].clone(); + for value in &values[1..] { + sum = sum + *value; + } + sum + } + + // empty if containing only zeros + pub fn is_empty(&self) -> bool { + for value in self.elems.values() { + if !value.is_zero() { + return false; + } + } + true + } + + pub fn pretty_print(&self) -> String { + let mut s = "[".to_string(); + let mut i = MclFr::zero(); + let one = &MclFr::from(1); + + while &i < &self.size { + s += &format!("{:?}", self.get(&i)); + if &i < &(&self.size - one) { + s += ","; + } + i.inc(); + } + s += "]"; + s + } + + // returns a vector of range [from..to) + pub fn slice(&self, from: &MclFr, to: &MclFr) -> Self { + let size = to - from; + let mut new_sv = MclSparseVec::new(&size); + + let mut i = from.clone(); + while &i < &to { + new_sv.set(&(&i - from), &self[&i]); + i.inc(); + } + new_sv + } + + pub fn concat(&self, other: &MclSparseVec) -> MclSparseVec { + let size = &self.size + &other.size; + let mut sv = MclSparseVec::new(&size); + + let mut i = MclFr::zero(); + // copy self to new sv + { + let mut j = MclFr::zero(); + while &j < &self.size { + sv[&i] = self[&j].clone(); + j.inc(); + i.inc(); + } + } + // copy other to new sv + { + let mut j = MclFr::zero(); + while &j < &other.size { + sv[&i] = other[&j].clone(); + j.inc(); + i.inc(); + } + } + sv + } +} + +impl PartialEq for MclSparseVec { + fn eq(&self, other: &MclSparseVec) -> bool { + if self.size != other.size { return false; } + + for index in self.elems.keys() { + let other_elem = other.get(index); + let this_elem = self.get(index); + if this_elem != other_elem { return false; } + } + for index in other.elems.keys() { + let other_elem = other.get(index); + let this_elem = self.get(index); + if this_elem != other_elem { return false; } + } + true + } +} + +impl Index<&MclFr> for MclSparseVec { + type Output = MclFr; + + fn index(&self, index: &MclFr) -> &Self::Output { + &self.get(index) + } +} + +impl IndexMut<&MclFr> for MclSparseVec { + fn index_mut(&mut self, index: &MclFr) -> &mut Self::Output { + if !self.elems.contains_key(index) { + self.elems.insert(index.clone(), MclFr::from(0)); + } + self.elems.get_mut(index).unwrap() + } +} + +impl From<&Vec> for MclSparseVec { + fn from(elems: &Vec) -> Self { + assert!(elems.len() != 0, "Cannot build vector from empty element list"); + let size = MclFr::from(elems.len()); + let mut vec = MclSparseVec::new(&size); + + for (i, v) in elems.iter().enumerate() { + if !v.is_zero() { + vec.set(&MclFr::from(i), v); + } + } + vec + } +} + +// returns Hadamard product +impl Mul<&MclSparseVec> for &MclSparseVec { + type Output = MclSparseVec; + + fn mul(self, rhs: &MclSparseVec) -> Self::Output { + if self.size != rhs.size { + panic!("Expected size of rhs to be {:?}, but got {:?}", self.size, rhs.size); + } + + let mut ret = MclSparseVec::new(&self.size); + for index in self.elems.keys() { + let l = self.get(index); + let r = rhs.get(index); + if !l.is_zero() && !r.is_zero() { + ret.set(index, &(l * r)); + } + } + ret + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::building_block::mcl::mcl_initializer::MclInitializer; + + #[test] + #[should_panic] + fn test_from_empty_list() { + MclInitializer::init(); + std::panic::set_hook(Box::new(|_| {})); // suppress stack trace + let _ = MclSparseVec::from(&vec![]); + } + + #[test] + fn test_slice() { + MclInitializer::init(); + let zero = &MclFr::zero(); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let three = &MclFr::from(3); + let elems = vec![ + zero.clone(), + one.clone(), + two.clone(), + three.clone(), + ]; + let sv = MclSparseVec::from(&elems); + let sv2 = sv.slice(one, three); + + assert_eq!(&sv2.size, two); + assert_eq!(&sv2[zero], one); + assert_eq!(&sv2[one], two); + } + + #[test] + fn test_from_non_empty_list() { + MclInitializer::init(); + let zero = &MclFr::zero(); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let elems = vec![one.clone(), two.clone()]; + let vec = MclSparseVec::from(&elems); + assert_eq!(&vec.size, two); + assert_eq!(&vec[zero], one); + assert_eq!(&vec[one], two); + } + + #[test] + fn test_from_non_empty_zero_only_list() { + MclInitializer::init(); + let zero = &MclFr::zero(); + let two = &MclFr::from(2); + let elems = vec![zero.clone(), zero.clone()]; + let vec = MclSparseVec::from(&elems); + assert_eq!(&vec.size, two); + assert_eq!(vec.elems.len(), 0); + } + + #[test] + #[should_panic] + fn test_new_empty_vec() { + MclInitializer::init(); + std::panic::set_hook(Box::new(|_| {})); // suppress stack trace + MclSparseVec::new(&MclFr::zero()); + } + + #[test] + #[should_panic] + fn test_bad_set() { + MclInitializer::init(); + std::panic::set_hook(Box::new(|_| {})); // suppress stack trace + + let mut vec = MclSparseVec::new(&MclFr::from(3)); + assert_eq!(vec.elems.len(), 0); + + vec.set(&MclFr::from(3), &MclFr::from(2)); + } + + #[test] + fn test_good_set() { + MclInitializer::init(); + let mut vec = MclSparseVec::new(&MclFr::from(3)); + assert_eq!(vec.elems.len(), 0); + + let two = &MclFr::from(2); + vec.set(&MclFr::from(2), two); + assert_eq!(vec.elems.len(), 1); + assert_eq!(vec.elems.get(two).unwrap(), &MclFr::from(2)); + + // setting the same index should overwrite + vec.set(&MclFr::from(2), &MclFr::from(3)); + assert_eq!(vec.elems.len(), 1); + assert_eq!(vec.elems.get(two).unwrap(), &MclFr::from(3)); + + // setting 0 should do nothing + vec.set(&MclFr::from(2), &MclFr::zero()); + assert_eq!(vec.elems.len(), 1); + assert_eq!(vec.elems.get(two).unwrap(), &MclFr::from(3)); + } + + #[test] + fn test_assign() { + MclInitializer::init(); + let mut vec = MclSparseVec::new(&MclFr::from(3)); + assert_eq!(vec.elems.len(), 0); + + vec.set(&MclFr::from(2), &MclFr::from(2)); + assert_eq!(vec.elems.len(), 1); + assert_eq!(vec.elems.get(&MclFr::from(2)).unwrap(), &MclFr::from(2)); + + let indices = vec.indices(); + assert_eq!(indices.len(), 1); + assert_eq!(indices[0], MclFr::from(2)); + + // setting the same index should overwrite + vec.set(&MclFr::from(2), &MclFr::from(3)); + assert_eq!(vec.elems.len(), 1); + assert_eq!(vec.elems.get(&MclFr::from(2)).unwrap(), &MclFr::from(3)); + } + + #[test] + fn test_good_get() { + MclInitializer::init(); + let zero = &MclFr::zero(); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let three = &MclFr::from(3); + let mut vec = MclSparseVec::new(&MclFr::from(3)); + vec.set(zero, one); + vec.set(one, two); + vec.set(two, three); + assert_eq!(vec.get(zero), one); + assert_eq!(vec.get(one), two); + assert_eq!(vec.get(two), three); + } + + #[test] + #[should_panic] + fn test_get_index_out_of_range() { + MclInitializer::init(); + std::panic::set_hook(Box::new(|_| {})); // suppress stack trace + + let vec = MclSparseVec::new(&MclFr::from(1)); + vec.get(&MclFr::from(2)); + } + + #[test] + fn test_get_index_without_value() { + MclInitializer::init(); + std::panic::set_hook(Box::new(|_| {})); // suppress stack trace + + let zero = &MclFr::zero(); + let vec = MclSparseVec::new(&MclFr::from(3)); + assert_eq!(vec.get(&MclFr::from(0)), zero); + assert_eq!(vec.get(&MclFr::from(1)), zero); + assert_eq!(vec.get(&MclFr::from(2)), zero); + } + + #[test] + fn test_indices() { + MclInitializer::init(); + let mut vec = MclSparseVec::new(&MclFr::from(3)); + + vec.set(&MclFr::from(1), &MclFr::from(2)); + vec.set(&MclFr::from(2), &MclFr::from(4)); + + let indices = vec.indices(); + + assert_eq!(indices.len(), 2); + assert!(indices.contains(&MclFr::from(1))); + assert!(indices.contains(&MclFr::from(2))); + } + + #[test] + fn test_mutiply_no_matching_elems() { + MclInitializer::init(); + let mut vec_a = MclSparseVec::new(&MclFr::from(3)); + let mut vec_b = MclSparseVec::new(&MclFr::from(3)); + + vec_a.set(&MclFr::from(1), &MclFr::from(2)); + vec_b.set(&MclFr::from(2), &MclFr::from(3)); + + let vec_c = &vec_a * &vec_b; + + assert_eq!(vec_c.elems.len(), 0); + } + + #[test] + fn test_mutiply_elems_fully_matching_1_elem() { + MclInitializer::init(); + let mut vec_a = MclSparseVec::new(&MclFr::from(3)); + let mut vec_b = MclSparseVec::new(&MclFr::from(3)); + + vec_a.set(&MclFr::from(1), &MclFr::from(2)); + vec_b.set(&MclFr::from(1), &MclFr::from(3)); + + let vec_c = &vec_a * &vec_b; + + assert_eq!(vec_c.elems.len(), 1); + assert_eq!(vec_c.get(&MclFr::from(1)), &MclFr::from(6)); + } + + #[test] + fn test_mutiply_elems_fully_matching_2_elems() { + MclInitializer::init(); + let mut vec_a = MclSparseVec::new(&MclFr::from(3)); + let mut vec_b = MclSparseVec::new(&MclFr::from(3)); + + vec_a.set(&MclFr::from(1), &MclFr::from(2)); + vec_a.set(&MclFr::from(2), &MclFr::from(3)); + vec_b.set(&MclFr::from(1), &MclFr::from(4)); + vec_b.set(&MclFr::from(2), &MclFr::from(5)); + + let vec_c = &vec_a * &vec_b; + + assert_eq!(vec_c.elems.len(), 2); + assert_eq!(vec_c.get(&MclFr::from(1)), &MclFr::from(8)); + assert_eq!(vec_c.get(&MclFr::from(2)), &MclFr::from(15)); + } + + #[test] + fn test_mutiply_elems_partially_matching() { + MclInitializer::init(); + let mut vec_a = MclSparseVec::new(&MclFr::from(3)); + let mut vec_b = MclSparseVec::new(&MclFr::from(3)); + + vec_a.set(&MclFr::from(1), &MclFr::from(2)); + vec_a.set(&MclFr::from(2), &MclFr::from(5)); + vec_b.set(&MclFr::from(1), &MclFr::from(3)); + + let vec_c = &vec_a * &vec_b; + + assert_eq!(vec_c.elems.len(), 1); + assert_eq!(vec_c.get(&MclFr::from(1)), &MclFr::from(6)); + } + + #[test] + fn test_sum() { + MclInitializer::init(); + let mut vec = MclSparseVec::new(&MclFr::from(3)); + + vec.set(&MclFr::from(1), &MclFr::from(2)); + vec.set(&MclFr::from(2), &MclFr::from(4)); + + let sum = vec.sum(); + assert_eq!(sum, MclFr::from(6)); + } + + #[test] + fn test_eq_different_sizes() { + MclInitializer::init(); + let vec_a = MclSparseVec::new(&MclFr::from(3)); + let vec_b = MclSparseVec::new(&MclFr::from(4)); + assert_ne!(vec_a, vec_b); + assert_ne!(vec_b, vec_a); + } + + #[test] + fn test_eq_empty() { + MclInitializer::init(); + let vec_a = MclSparseVec::new(&MclFr::from(3)); + let vec_b = MclSparseVec::new(&MclFr::from(3)); + assert_eq!(vec_a, vec_b); + assert_eq!(vec_b, vec_a); + } + + #[test] + fn test_eq_non_empty() { + MclInitializer::init(); + let mut vec_a = MclSparseVec::new(&MclFr::from(3)); + let mut vec_b = MclSparseVec::new(&MclFr::from(3)); + + vec_a.set(&MclFr::from(1), &MclFr::from(92)); + vec_b.set(&MclFr::from(1), &MclFr::from(92)); + assert_eq!(vec_a, vec_b); + assert_eq!(vec_b, vec_a); + } + + #[test] + fn test_not_eq_non_empty() { + MclInitializer::init(); + let mut vec_a = MclSparseVec::new(&MclFr::from(3)); + let mut vec_b = MclSparseVec::new(&MclFr::from(3)); + + vec_a.set(&MclFr::from(1), &MclFr::from(92)); + vec_b.set(&MclFr::from(1), &MclFr::from(13)); + assert_ne!(vec_a, vec_b); + assert_ne!(vec_b, vec_a); + } + + #[test] + fn test_iterator() { + MclInitializer::init(); + let mut sv = MclSparseVec::new(&MclFr::from(3)); + sv.set(&MclFr::from(0), &MclFr::from(1)); + sv.set(&MclFr::from(1), &MclFr::from(2)); + sv.set(&MclFr::from(2), &MclFr::from(3)); + + let it = &mut sv.iter(); + assert!(&it.next().unwrap() == &MclFr::from(1)); + assert!(&it.next().unwrap() == &MclFr::from(2)); + assert!(&it.next().unwrap() == &MclFr::from(3)); + assert!(it.next() == None); + } + + #[test] + fn test_concat() { + MclInitializer::init(); + let mut sv1 = MclSparseVec::new(&MclFr::from(2)); + sv1.set(&MclFr::from(0), &MclFr::from(1)); + sv1.set(&MclFr::from(1), &MclFr::from(2)); + + let mut sv2 = MclSparseVec::new(&MclFr::from(2)); + sv2.set(&MclFr::from(0), &MclFr::from(3)); + sv2.set(&MclFr::from(1), &MclFr::from(4)); + + let sv3 = sv1.concat(&sv2); + assert!(sv3.get(&MclFr::from(0)) == &MclFr::from(1)); + assert!(sv3.get(&MclFr::from(1)) == &MclFr::from(2)); + assert!(sv3.get(&MclFr::from(2)) == &MclFr::from(3)); + assert!(sv3.get(&MclFr::from(3)) == &MclFr::from(4)); + } +} diff --git a/src/building_block/mcl/mod.rs b/src/building_block/mcl/mod.rs new file mode 100644 index 0000000..ff2d562 --- /dev/null +++ b/src/building_block/mcl/mod.rs @@ -0,0 +1,11 @@ +pub mod mcl_g1; +pub mod mcl_g2; +pub mod mcl_gt; +pub mod mcl_fr; +pub mod mcl_initializer; +pub mod mcl_sparse_matrix; +pub mod mcl_sparse_vec; +pub mod pairing; +pub mod polynomial; +pub mod qap; + diff --git a/src/building_block/mcl/pairing.rs b/src/building_block/mcl/pairing.rs new file mode 100644 index 0000000..b4b4b03 --- /dev/null +++ b/src/building_block/mcl/pairing.rs @@ -0,0 +1,130 @@ +use mcl_rust::*; +use crate::building_block::mcl::{ + mcl_g1::MclG1, + mcl_g2::MclG2, + mcl_gt::MclGT, +}; + +#[derive(Debug, Clone)] +pub struct Pairing; + +impl Pairing { + pub fn e(&self, p1: &MclG1, p2: &MclG2) -> MclGT { + let mut v = GT::zero(); + pairing(&mut v, &p1.v, &p2.v); + MclGT::from(&v) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::building_block::mcl::{ + mcl_fr::MclFr, + mcl_initializer::MclInitializer, + }; + + fn test( + pairing: &Pairing, + pair: &dyn Fn(&Pairing, &MclG1, &MclG2) -> MclGT, + p1: &MclG1, + p2: &MclG2, + ) -> bool { + let ten = MclFr::from(10); + let ten_p1s = p1 * &ten; + + // test e(p1 + ten_p1s, p2) = e(p1, p2) e(ten_p1s, p2) + let lhs = pair(pairing, &(p1 + &ten_p1s), p2); + let rhs1 = pair(pairing, p1, p2); + let rhs2 = pair(pairing, &ten_p1s, p2); + + let rhs = rhs1 * rhs2; + + lhs == rhs + } + + fn test_with_generators( + pair: &dyn Fn(&Pairing, &MclG1, &MclG2) -> MclGT, + ) { + let pairing = &Pairing; + let p1 = MclG1::g(); + let p2 = MclG2::g(); + let res = test(pairing, pair, &p1, &p2); + assert!(res); + } + + fn test_with_random_points( + pair: &dyn Fn(&Pairing, &MclG1, &MclG2) -> MclGT, + ) { + let mut errors = 0; + let num_tests = 1; + + for _ in 0..num_tests { + let pairing = &Pairing; + let p1 = MclG1::get_random_point(); + let p2 = MclG2::get_random_point(); + let res = test(pairing, pair, &p1, &p2); + if res == false { + errors += 1; + } + } + assert!(errors == 0); + } + + fn test_plus_to_mul(pair: &dyn Fn(&Pairing, &MclG1, &MclG2) -> MclGT, + ) { + let pairing = &Pairing; + let one = &MclG2::g(); + + let p = &(MclG1::g() + MclG1::g()); + + let lhs = { + let p_plus_p = p + p; + pair(pairing, &p_plus_p, one) + }; + + let rhs = { + let a = &pair(pairing, &p, one); + a * a + }; + assert!(lhs == rhs); + } + + #[test] + fn test_weil_pairing_with_generators() { + MclInitializer::init(); + test_with_generators(&Pairing::e); + } + + #[test] + fn test_weil_pairing_with_random_points() { + MclInitializer::init(); + test_with_random_points(&Pairing::e); + } + + #[test] + fn test_tate_pairing_with_test_plus_to_mul() { + MclInitializer::init(); + test_plus_to_mul(&Pairing::e); + } + + #[test] + fn test_signature_verification() { + MclInitializer::init(); + + let pairing = &Pairing; + let g1 = &MclG1::g(); + let sk = &MclFr::from(2); + let pk = &(g1 * sk); + + let m = &b"hamburg steak".to_vec(); + let hash_m = &MclG2::hash_and_map(m); + + // e(pk, H(m)) = e(g1*sk, H(m)) = e(g1, sk*H(m)) + let lhs = pairing.e(pk, hash_m); + let rhs = pairing.e(g1, &(hash_m * sk)); + + assert!(lhs == rhs); + } +} + diff --git a/src/building_block/mcl/polynomial.rs b/src/building_block/mcl/polynomial.rs new file mode 100644 index 0000000..2ab39b5 --- /dev/null +++ b/src/building_block/mcl/polynomial.rs @@ -0,0 +1,1233 @@ +use crate::building_block::mcl::{ + mcl_fr::MclFr, + mcl_g1::MclG1, + mcl_g2::MclG2, + mcl_sparse_vec::MclSparseVec, +}; +use num_traits::Zero; +use std::{ + fmt::{Debug, Formatter}, + ops::{ + Add, + Deref, + Mul, + Sub, + AddAssign, MulAssign, + }, + convert::From, +}; + +// TODO use SparseVec instead of Vec to hold coeffs +#[derive(Clone)] +pub struct Polynomial { + pub coeffs: Vec, // e.g. 2x^3 + 5x + 9 -> [9, 5, 0, 2] + _private: (), // to force using new() +} + +impl Deref for Polynomial { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.coeffs + } +} + +impl From<&MclSparseVec> for Polynomial { + fn from(vec: &MclSparseVec) -> Self { + let mut i = MclFr::zero(); + let mut coeffs = vec![]; + while i < vec.size { + let v = vec.get(&i); + coeffs.push(v.clone()); + i.inc(); + } + let p = Polynomial::new(&coeffs); + p.normalize() + } +} + +impl PartialEq for Polynomial { + fn eq(&self, rhs: &Polynomial) -> bool { + let (smaller, larger) = if self.coeffs.len() < rhs.coeffs.len() { + (&self.coeffs, &rhs.coeffs) + } else { + (&rhs.coeffs, &self.coeffs) + }; + + // if larger is superset, it contains other non-zero terms + if smaller.len() != larger.len() { return false; } + + // check if smaller is a subset of larger + for i in 0..smaller.len() { + if &smaller[i] != &larger[i] { return false; } + } + true + } +} + +impl Debug for Polynomial { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + let one = &MclFr::from(1); + let mut terms = vec![]; + let last_idx = self.coeffs.len() - 1; + + for (i, coeff) in self.coeffs.iter().rev().enumerate() { + if !coeff.is_zero() { + let mut s = String::new(); + // write number + if coeff != one || i == self.coeffs.len() - 1 { + s.push_str(&format!("{:?}", coeff)); + } + + // if not the constant term, write variable after number + if i < last_idx { + s.push_str("x"); + // write exponent if x^2 or higher + if i < last_idx - 1 { // second to last corresponds to x^1 + s.push_str(&format!("^{}", self.coeffs.len() - 1 - i)); + } + } + terms.push(s); + } + } + + let expr = terms.iter().map(|x| format!("{}", x)).collect::>().join(" + "); + write!(f, "{}", expr) + } +} + +#[derive(Debug)] +pub enum DivResult { + Quotient(Polynomial), + QuotientRemainder((Polynomial, Polynomial)), +} + +impl Zero for Polynomial { + fn zero() -> Self { + let coeffs = &vec![MclFr::zero()]; + Polynomial::new(coeffs) + } + + fn is_zero(&self) -> bool { + self.coeffs.len() == 1 && self.coeffs[0].is_zero() + } +} + +impl Polynomial { + pub fn new(coeffs: &Vec) -> Self { + if coeffs.len() == 0 { panic!("coeffs is empty"); } + let x = Polynomial { + coeffs: coeffs.clone(), + _private: () + }; + x.normalize() + } + + // trim trailing zero-coeff terms + fn normalize(&self) -> Polynomial { + let mut new_len = self.coeffs.len(); + for i in 0..(self.coeffs.len() - 1) { // seek from end to beg and always keep the 0th element + let coeff = &self.coeffs[&self.coeffs.len() - 1 - i]; + if !coeff.is_zero() { break; } + new_len -= 1; + } + + let mut norm_coeffs = vec![]; + for coeff in &self.coeffs[0..new_len] { + norm_coeffs.push(coeff.clone()); + } + Polynomial { coeffs: norm_coeffs, _private: () } + } + + pub fn plus(&self, rhs: &Polynomial) -> Polynomial { + let (smaller, larger) = if self.coeffs.len() < rhs.coeffs.len() { + (&self.coeffs, &rhs.coeffs) + } else { + (&rhs.coeffs, &self.coeffs) + }; + + let mut coeffs = vec![]; + for i in 0..larger.len() { + if i < smaller.len() { + coeffs.push(&smaller[i] + &larger[i]); + } else { + coeffs.push(larger[i].clone()); + } + } + let x = Polynomial { coeffs, _private: () }; + x.normalize() // normalizing b/c addition can make term coefficect zero + } + + pub fn multiply_by(&self, rhs: &Polynomial) -> Polynomial { + // degree of polynomial is coeffs.len - 1 + let self_degree = self.coeffs.len() - 1; + let rhs_degree = rhs.coeffs.len() - 1; + + // coeffs len of the mul result is sum of self and rhs degrees + 1 + let new_len = self_degree + rhs_degree + 1; + let mut coeffs = vec![MclFr::zero(); new_len]; + + for i in 0..self.coeffs.len() { + for j in 0..rhs.coeffs.len() { + let coeff = &self.coeffs[i] * &rhs.coeffs[j]; + let degree = i + j; + coeffs[degree] = &coeffs[degree] + coeff; + } + } + Polynomial { coeffs, _private: () } + } + + // not supporting cases where rhs degree > lhs degree + pub fn minus(&self, rhs: &Polynomial) -> Polynomial { + assert!(self.coeffs.len() >= rhs.coeffs.len()); + let mut coeffs = self.coeffs.clone(); + + for i in 0..rhs.coeffs.len() { + coeffs[i] = &coeffs[i] - &rhs.coeffs[i]; + } + let p = Polynomial { coeffs, _private: () }; + p.normalize() + } + + pub fn divide_by(&self, rhs: &Polynomial) -> DivResult { + let mut dividend = self.clone(); + + let divisor = rhs; + let quotient_degree = dividend.len() - divisor.len(); + let divisor_coeff = &divisor[divisor.len() - 1]; + assert!(!divisor_coeff.is_zero(), "found zero coeff at highest index. use Polynomial constructor"); + + let mut quotient_coeffs = vec![MclFr::zero(); quotient_degree + 1]; + + while !dividend.is_zero() && dividend.len() >= divisor.len() { + let dividend_coeff = ÷nd[dividend.len() - 1]; + + // create a term to multiply with divisor + let term_coeff = dividend_coeff * divisor_coeff.inv(); + let term_degree = dividend.len() - divisor.len(); + let mut term_vec = vec![MclFr::zero(); term_degree + 1]; + term_vec[term_degree] = term_coeff.clone(); + let term_poly = Polynomial::new(&term_vec); + + // reflect term coeff to result quotient + quotient_coeffs[term_degree] = term_coeff; + + let poly2subtract = divisor.multiply_by(&term_poly); + + // update dividend for the next round + dividend = dividend.sub(&poly2subtract); + } + + if dividend.is_zero() { + DivResult::Quotient(Polynomial { coeffs: quotient_coeffs, _private: () }) + } else { + let quotient = Polynomial { coeffs: quotient_coeffs, _private: () }; + DivResult::QuotientRemainder((quotient, dividend)) + } + } + + pub fn eval_at(&self, x: &MclFr) -> MclFr { + let mut multiplier = MclFr::from(1); + let mut sum = MclFr::zero(); + + for coeff in &self.coeffs { + sum = sum + coeff * &multiplier; + multiplier = &multiplier * x; + } + sum + } + + pub fn eval_from_1_to_n(&self, n: &MclFr) -> MclSparseVec { + let one = &MclFr::from(1); + + let mut vec = MclSparseVec::new(n); + let mut i = MclFr::zero(); + while &i < n { + i.inc(); + let res = self.eval_at(&i); + vec.set(&(&i - one), &res); + } + vec + } + + pub fn degree(&self) -> MclFr { + if self.coeffs.len() == 0 { + panic!("should have at least 1 coeff. check code"); + } + MclFr::from(self.coeffs.len() - 1) + } + + #[allow(non_snake_case)] + pub fn eval_with_g1_hidings( + &self, + powers: &[MclG1], + ) -> MclG1 { + let mut sum = MclG1::zero(); + for i in 0..self.coeffs.len() { + sum = sum + (&powers[i] * &self.coeffs[i]); + } + sum + } + + #[allow(non_snake_case)] + pub fn eval_with_g2_hidings( + &self, + powers: &[MclG2], + ) -> MclG2 { + let mut sum = MclG2::zero(); + for i in 0..self.coeffs.len() { + sum = sum + (&powers[i] * &self.coeffs[i]); + } + sum + } + + pub fn to_sparse_vec(&self, size: usize) -> MclSparseVec { + let size = MclFr::from(size); + let mut vec = MclSparseVec::new(&size); + + for (i, coeff) in self.coeffs.iter().enumerate() { + let i = MclFr::from(i); + vec.set(&i, coeff); + } + vec + } +} + +macro_rules! impl_add { + ($rhs: ty, $target: ty) => { + impl<'a> Add<$rhs> for $target { + type Output = Polynomial; + + fn add(self, rhs: $rhs) -> Self::Output { + self.plus(&rhs) + } + } + }; +} +impl_add!(Polynomial, Polynomial); +impl_add!(&Polynomial, Polynomial); +impl_add!(Polynomial, &Polynomial); +impl_add!(&Polynomial, &Polynomial); + +impl AddAssign for Polynomial { + fn add_assign(&mut self, rhs: Polynomial) { + *self = &*self + &rhs; + } +} + +impl AddAssign<&Polynomial> for Polynomial { + fn add_assign(&mut self, rhs: &Polynomial) { + *self = &*self + rhs; + } +} + +macro_rules! impl_mul { + ($rhs: ty, $target: ty) => { + impl<'a> Mul<$rhs> for $target { + type Output = Polynomial; + + fn mul(self, rhs: $rhs) -> Self::Output { + self.multiply_by(&rhs) + } + } + }; +} +impl_mul!(Polynomial, Polynomial); +impl_mul!(Polynomial, &Polynomial); +impl_mul!(&Polynomial, Polynomial); +impl_mul!(&Polynomial, &Polynomial); + +impl<'a> Mul<&MclFr> for &Polynomial { + type Output = Polynomial; + + fn mul(self, rhs: &MclFr) -> Self::Output { + Polynomial { + coeffs: self.coeffs.iter().map(|coeff| coeff * rhs).collect(), + _private: (), + } + } +} + +impl MulAssign for Polynomial { + fn mul_assign(&mut self, rhs: Polynomial) { + *self = &*self * &rhs; + } +} + +impl MulAssign<&Polynomial> for Polynomial { + fn mul_assign(&mut self, rhs: &Polynomial) { + *self = &*self * rhs; + } +} + +impl<'a> Sub<&Polynomial> for Polynomial { + type Output = Polynomial; + + fn sub(self, rhs: &Polynomial) -> Self::Output { + self.minus(rhs) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::Rng; + use super::DivResult::{Quotient, QuotientRemainder}; + use crate::building_block::mcl::mcl_initializer::MclInitializer; + + #[test] + fn test_to_sparse_vec() { + MclInitializer::init(); + + let zero = &MclFr::zero(); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let three = &MclFr::from(3); + let four = &MclFr::from(4); + + // 2x + 3 + let coeffs = vec![ + two.clone(), + three.clone(), + ]; + let p = Polynomial::new(&coeffs); + let vec = p.to_sparse_vec(four.to_usize()); + + assert_eq!(&vec.size, four); + assert_eq!(vec.get(zero), two); + assert_eq!(vec.get(one), three); + assert_eq!(vec.get(two), zero); + assert_eq!(vec.get(three), zero); + } + + #[test] + fn test_degree() { + MclInitializer::init(); + + // degree 0 + { + let p = Polynomial::new(&vec![ + MclFr::from(2), + ]); + assert_eq!(p.degree(), MclFr::zero()); + + // 0 coeff case + let p = p.normalize(); + assert_eq!(p.degree(), MclFr::zero()); + } + // degree 1 + { + let p = Polynomial::new(&vec![ + MclFr::from(2), + MclFr::from(3), + ]); + assert_eq!(p.degree(), MclFr::from(1)); + } + } + + #[test] + fn test_from_sparse_vec() { + MclInitializer::init(); + + let zero = &MclFr::from(0); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let three = &MclFr::from(3); + { + let mut vec = MclSparseVec::new(&MclFr::from(2)); + vec.set(zero, two); + vec.set(one, three); + + let p = Polynomial::from(&vec); + assert_eq!(&p.degree(), one); + assert_eq!(&p.coeffs[0], two); + assert_eq!(&p.coeffs[1], three); + } + { + let mut vec = MclSparseVec::new(&MclFr::from(2)); + vec.set(zero, two); + vec.set(one, zero); + + let p = Polynomial::from(&vec); + assert_eq!(&p.degree(), zero); + assert_eq!(&p.coeffs[0], two); + } + } + + #[test] + fn test_eval_at() { + MclInitializer::init(); + + let zero = MclFr::from(0); + let one = MclFr::from(1); + let two = MclFr::from(2); + { // 8 + let zero = MclFr::from(0); + let eight = MclFr::from(8); + let p = Polynomial::new(&vec![ + eight.clone(), + ]); + assert_eq!(&p.eval_at(&zero), &eight); + } + { // 3x + 8 + let p = &Polynomial::new(&vec![ + MclFr::from(8), + MclFr::from(3), + ]); + assert_eq!(p.eval_at(&zero), MclFr::from(8)); + assert_eq!(p.eval_at(&one), MclFr::from(11)); + assert_eq!(p.eval_at(&two), MclFr::from(14)); + } + { // 2x^2 + 3x + 8 + let p = &Polynomial::new(&vec![ + MclFr::from(8), + MclFr::from(3), + MclFr::from(2), + ]); + assert_eq!(p.eval_at(&zero), MclFr::from(8)); + assert_eq!(p.eval_at(&one), MclFr::from(13)); + assert_eq!(p.eval_at(&two), MclFr::from(22)); + } + } + + #[test] + fn test_div_3_2_no_remainder() { + MclInitializer::init(); + { + /* + _____________ + x+2 ) x² + x + 5 + */ + let dividend = Polynomial::new(&vec![ + MclFr::from(5), + MclFr::from(1), + MclFr::from(1), + ]); + let divisor = Polynomial::new(&vec![ + MclFr::from(2), + MclFr::from(1), + ]); + let res = dividend.divide_by(&divisor); + if let QuotientRemainder((q, r)) = res { + assert_eq!(dividend, divisor * q + r); + } else if let Quotient(q) = res { + panic!("expected remainder, but got quotient {:?} w/ no remainder", q); + } else { + panic!("should not be visited"); + } + } + } + + #[test] + fn test_div_2_2() { + MclInitializer::init(); + { + /* + _________ + x+7 ) 2x + 3 + */ + let dividend = Polynomial::new(&vec![ + MclFr::from(3), + MclFr::from(2), + ]); + let divisor = Polynomial::new(&vec![ + MclFr::from(7), + MclFr::from(1), + ]); + let quotient = Polynomial::new(&vec![ + MclFr::from(2), + ]); + let remainder = Polynomial::new(&vec![ + -MclFr::from(11), + ]); + let res = dividend.divide_by(&divisor); + if let QuotientRemainder((q, r)) = res { + assert!(q == quotient); + assert!(r == remainder); + } else if let Quotient(q) = res { + panic!("expected remainder, but got quotient {:?} w/ no remainder", q); + } else { + panic!("should not be visited"); + } + } + } + + #[test] + fn test_div_1_1() { + MclInitializer::init(); + { + /* + (32 * 10^-1) + 31461525105075714287668644304911579502614331500316582693562195219963148710711 + ________ + 10 ) 32 + 32 + -- + 0 + */ + let dividend = Polynomial::new(&vec![ + MclFr::from(32), + ]); + let divisor = Polynomial::new(&vec![ + MclFr::from(10), + ]); + let quotient = Polynomial::new(&vec![ + MclFr::from(32) * MclFr::from(10).inv(), + ]); + let res = dividend.divide_by(&divisor); + if let QuotientRemainder((q, r)) = res { + panic!("no remainder expected, but got q={:?}, r={:?})", q, r); + } else if let Quotient(q) = res { + assert!(q == quotient); + } else { + panic!("should not be visited"); + } + } + } + + #[test] + fn test_div_5_2() { + MclInitializer::init(); + { + let dividend = Polynomial::new(&vec![ + MclFr::from(5), // 0 + MclFr::from(0), // 1 + MclFr::from(0), // 2 + MclFr::from(4), // 3 + MclFr::from(7), // 4 + MclFr::from(0), // 5 + MclFr::from(3), // 6 + ]); + let divisor = Polynomial::new(&vec![ + MclFr::from(4), // 0 + MclFr::from(0), // 1 + MclFr::from(0), // 2 + MclFr::from(3), // 3 + MclFr::from(1), // 4 + ]); + let res = ÷nd.divide_by(&divisor); + if let QuotientRemainder((q, r)) = res { + assert_eq!(dividend, divisor * q + r); + } else if let Quotient(q) = res { + panic!("expected remainder, but got quotient {:?} w/ no remainder", q); + } else { + panic!("should not be visited"); + } + } + } + + fn gen_random_polynomial(degree: usize, max_coeff: u32) -> Polynomial { + let mut coeffs = vec![]; + + for _ in 0..degree { + let coeff: u32 = rand::thread_rng().gen_range(0..max_coeff); + coeffs.push(MclFr::from(coeff as i32)); + } + Polynomial::new(&coeffs) + } + + #[test] + fn test_div_random_divisible() { + MclInitializer::init(); + + let max_coeff = 100; + let min_divisor_degree = 30; + let max_divisor_degree = 100; + + for _ in 0..10 { + let divisor_degree = rand::thread_rng().gen_range(min_divisor_degree..max_divisor_degree); + let quotient_degree = rand::thread_rng().gen_range(1..divisor_degree); + + let divisor = &gen_random_polynomial(divisor_degree, max_coeff); + let quotient = &gen_random_polynomial(quotient_degree, max_coeff); + let dividend = divisor.multiply_by(quotient); + + match &÷nd.divide_by(divisor) { + Quotient(q) => { + assert!(q == quotient); + }, + QuotientRemainder(x) => { + panic!("unexpected remainder {:?}", x); + }, + }; + } + } + + #[test] + fn test_is_zero() { + MclInitializer::init(); + { + let a = Polynomial::new(&vec![ + MclFr::from(12), + MclFr::from(7), + ]); + assert!(!a.is_zero()); + } + { + let a = Polynomial::new(&vec![ + MclFr::from(7), + ]); + assert!(!a.is_zero()); + } + { + let a = Polynomial::new(&vec![ + MclFr::from(0), + ]); + assert!(a.is_zero()); + } + } + + #[test] + fn test_sub_2_2() { + MclInitializer::init(); + + // subtract small poly + { + let a = Polynomial::new(&vec![ + MclFr::from(12), + MclFr::from(7), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(3), + MclFr::from(4), + ]); + let c = Polynomial::new(&vec![ + MclFr::from(9), + MclFr::from(3), + ]); + assert!(a.sub(&b) == c); + } + // subtract bigger poly + { + let a = Polynomial::new(&vec![ + MclFr::from(12), + MclFr::from(7), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(15), + MclFr::from(8), + ]); + let c = Polynomial::new(&vec![ + -MclFr::from(3), + -MclFr::from(1), + ]); + assert!(a.sub(&b) == c); + } + // subtract the same poly + { + let a = &Polynomial::new(&vec![ + MclFr::from(12), + MclFr::from(7), + ]); + let c = Polynomial::new(&vec![ + MclFr::from(0), + ]); + println!("res = {:?}", a.minus(a)); + assert!(a.minus(&a) == c); + } + } + + #[test] + #[should_panic] + fn test_bad_sub() { + MclInitializer::init(); + + std::panic::set_hook(Box::new(|_| {})); + let a = Polynomial::new(&vec![ + MclFr::from(7), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(3), + MclFr::from(4), + ]); + let _ = a.sub(&b); + } + + #[test] + fn test_debug_print() { + MclInitializer::init(); + { + let a = Polynomial::new(&vec![ + MclFr::from(12), + MclFr::from(45), + MclFr::from(67), + ]); + println!("{:?}", a); + } + { + let a = Polynomial::new(&vec![ + MclFr::from(12), + MclFr::from(45), + ]); + println!("{:?}", a); + } + { + let a = Polynomial::new(&vec![ + MclFr::from(12), + ]); + println!("{:?}", a); + } + } + + #[test] + fn test_new_non_empty_vec() { + MclInitializer::init(); + + let p = Polynomial::new(&vec![ + MclFr::from(12), + ]); + assert!(p.coeffs.len() == 1); + assert!(p.coeffs[0] == MclFr::from(12)); + } + + #[test] + #[should_panic] + fn test_new_empty_vec() { + MclInitializer::init(); + + std::panic::set_hook(Box::new(|_| {})); + Polynomial::new(&vec![]); + } + + #[test] + fn test_normalize() { + MclInitializer::init(); + { + let a = Polynomial::new(&vec![ + MclFr::from(0), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(0), + ]); + assert!(&a == &b); + assert!(a.coeffs.len() == 1); + assert!(&a.coeffs[0] == &MclFr::from(0)); + } + { + let a = Polynomial::new(&vec![ + MclFr::from(1), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(1), + ]); + assert!(&a == &b); + assert!(a.coeffs.len() == 1); + assert!(&a.coeffs[0] == &MclFr::from(1)); + } + { + let a = Polynomial::new(&vec![ + MclFr::from(1), + MclFr::from(0), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(1), + ]); + assert!(&a == &b); + assert!(a.coeffs.len() == 1); + assert!(&a.coeffs[0] == &MclFr::from(1)); + } + { + let a = Polynomial::new(&vec![ + MclFr::from(1), + MclFr::from(0), + MclFr::from(0), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(1), + ]); + assert!(&a == &b); + assert!(a.coeffs.len() == 1); + assert!(&a.coeffs[0] == &MclFr::from(1)); + } + { + let a = Polynomial::new(&vec![ + MclFr::from(1), + MclFr::from(0), + MclFr::from(0), + MclFr::from(1), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(1), + MclFr::from(0), + MclFr::from(0), + MclFr::from(1), + ]); + assert!(&a == &b); + assert!(a.coeffs.len() == 4); + assert!(&a.coeffs[0] == &MclFr::from(1)); + assert!(&a.coeffs[1] == &MclFr::from(0)); + assert!(&a.coeffs[2] == &MclFr::from(0)); + assert!(&a.coeffs[3] == &MclFr::from(1)); + } + } + + #[test] + fn test_eq() { + MclInitializer::init(); + { + let a = Polynomial::new(&vec![ + MclFr::from(0), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(0), + ]); + assert!(&a == &b); + } + { + let a = Polynomial::new(&vec![ + MclFr::from(0), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(1), + ]); + assert!(&a != &b); + } + { + let a = Polynomial::new(&vec![ + MclFr::from(2), + MclFr::from(1), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(2), + MclFr::from(1), + ]); + assert!(&a == &b); + } + { + let a = Polynomial::new(&vec![ + MclFr::from(2), + MclFr::from(1), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(1), + MclFr::from(2), + ]); + assert!(&a != &b); + } + { + let a = Polynomial::new(&vec![ + MclFr::from(2), + MclFr::from(1), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(2), + MclFr::from(1), + MclFr::from(0), + ]); + assert!(&a == &b); + } + { + let a = Polynomial::new(&vec![ + MclFr::from(2), + MclFr::from(1), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(2), + MclFr::from(1), + MclFr::from(1), + ]); + assert!(&a != &b); + } + } + + #[test] + fn test_add() { + MclInitializer::init(); + // zero + zero + { + let a = Polynomial::new(&vec![ + MclFr::from(0), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(0), + ]); + let c = Polynomial::new(&vec![ + MclFr::from(0), + ]); + let res = a.add(&b); + assert!(&res == &c); + } + // zero + non-zero + { + let a = Polynomial::new(&vec![ + MclFr::from(0), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(12), + ]); + let c = Polynomial::new(&vec![ + MclFr::from(12), + ]); + let res = a.add(&b); + assert!(&res == &c); + } + // non-zero + non-zero + { + let a = Polynomial::new(&vec![ + MclFr::from(100), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(12), + ]); + let c = Polynomial::new(&vec![ + MclFr::from(112), + ]); + let res = a.add(&b); + assert!(&res == &c); + } + } + + #[test] + fn test_add_zero_term() { + MclInitializer::init(); + { + let a = Polynomial::new(&vec![ + MclFr::from(3), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(0), + ]); + let c = Polynomial::new(&vec![ + MclFr::from(3), + ]); + let res = a.add(&b); + assert!(&res == &c); + } + } + + #[test] + fn test_mul_deg_0_0() { + MclInitializer::init(); + { + // 0 * 0 + let a = Polynomial::new(&vec![ + MclFr::from(0), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(0), + ]); + let c = Polynomial::new(&vec![ + MclFr::from(0), + ]); + let res = a.mul(&b); + assert!(&res == &c); + } + { + // 1 * 0 + let a = Polynomial::new(&vec![ + MclFr::from(1), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(0), + ]); + let c = Polynomial::new(&vec![ + MclFr::from(0), + ]); + let res = a.mul(&b); + assert!(&res == &c); + } + { + // 0 * 1 + let a = Polynomial::new(&vec![ + MclFr::from(0), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(1), + ]); + let c = Polynomial::new(&vec![ + MclFr::from(0), + ]); + let res = a.mul(&b); + assert!(&res == &c); + } + { + // 2 * 3 + let a = Polynomial::new(&vec![ + MclFr::from(2), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(3), + ]); + let c = Polynomial::new(&vec![ + MclFr::from(6), + ]); + let res = a.mul(&b); + println!("res={:?}", res); + assert!(&res == &c); + } + } + + #[test] + fn test_mul_deg_1_0() { + MclInitializer::init(); + { + // (2x - 3) * 4 + let a = Polynomial::new(&vec![ + MclFr::from(3), + MclFr::from(2), + ]); + let b = Polynomial::new(&vec![ + MclFr::from(4), + ]); + let c = Polynomial::new(&vec![ + MclFr::from(12), + MclFr::from(8), + ]); + let res = a.mul(&b); + assert!(&res == &c); + } + } + + #[test] + fn test_mul_deg_1_1() { + MclInitializer::init(); + { + // 2x + 3 + let a = &Polynomial::new(&vec![ + MclFr::from(3), + MclFr::from(2), + ]); + // 5x + 4 + let b = &Polynomial::new(&vec![ + MclFr::from(4), + MclFr::from(5), + ]); + // 10x^2 + 23x + 12 + let c = &Polynomial::new(&vec![ + MclFr::from(12), + MclFr::from(23), + MclFr::from(10), + ]); + let res = a.multiply_by(&b); + println!("({:?})({:?}) = {:?}", a, b, res); + assert!(&res == c); + } + } + + #[test] + fn test_mul_const() { + MclInitializer::init(); + { + // 2x + 3 + let a = Polynomial::new(&vec![ + MclFr::from(3), + MclFr::from(2), + ]); + let ten = MclFr::from(10); + + // 20x + 30 + let exp = &Polynomial::new(&vec![ + MclFr::from(30), + MclFr::from(20), + ]); + + let act = &a * &ten; + + println!("({:?}) * {:?} = {:?}", a, &ten, &act); + assert!(&act == exp); + } + } + + #[test] + fn test_eval_from_1_to_n() { + MclInitializer::init(); + { + // evaluating for the same degree as the polynomial degree + let p = Polynomial::new(&vec![ + MclFr::from(2), + MclFr::from(3), + MclFr::from(5), + ]); + let three = &MclFr::from(3); + let vec = p.eval_from_1_to_n(three); + assert_eq!(&vec.size, three); + assert_eq!(vec.get(&MclFr::zero()), &MclFr::from(10)); + assert_eq!(vec.get(&MclFr::from(1)), &MclFr::from(28)); + assert_eq!(vec.get(&MclFr::from(2)), &MclFr::from(56)); + } + { + // evaluating for larger degree than the polynomial degree + let zero = &MclFr::from(0); + let three = &MclFr::from(3); + + let p = Polynomial::new(&vec![ + zero.clone(), + ]); + let vec = p.eval_from_1_to_n(three); + assert_eq!(&vec.size, three); + assert_eq!(vec.get(zero), zero); + assert_eq!(vec.get(&MclFr::from(1)), zero); + assert_eq!(vec.get(&MclFr::from(2)), zero); + } + } + + #[test] + fn test_eval_with_g1_hidings_1() { + MclInitializer::init(); + + let s = MclFr::from(3); + let s0g = &MclG1::g(); + let s1g = s0g * &s; + let s2g = s0g * &s * &s; + let s3g = s0g * &s * &s * &s; + let pows = vec![ + s0g.clone(), + s1g.clone(), + s2g.clone(), + s3g.clone(), + ]; + let two = MclFr::from(2); + let three = MclFr::from(3); + let four = MclFr::from(4); + let five = MclFr::from(5); + + let exp = + s0g * &two + + &s1g * &three + + &s2g * &four + + &s3g * &five + ; + // 5x^3 + 4x^2 + 3x + 2 + let p = Polynomial::new(&vec![ + two, + three, + four, + five, + ]); + let act = p.eval_with_g1_hidings(&pows); + + assert!(act == exp); + } + + #[test] + fn test_eval_with_g1_order() { + MclInitializer::init(); + + let s = MclFr::from(3); + + let e1549 = MclFr::from(1549); + let e3361 = MclFr::from(3361); + let e3607 = MclFr::from(3607); + let e822 = MclFr::from(822); + let e1990 = MclFr::from(1990); + let e496 = MclFr::from(496); + let e1698 = MclFr::from(1698); + let e2362 = MclFr::from(2362); + let e3670 = MclFr::from(3670); + + // 3670x^8 + 2362x^7 + 1698x^6 + 496x^5 + 1990x^4 + 822x^3 + 3607x^2 + 3361x + 1549 + let p = Polynomial::new(&vec![ + e1549, + e3361, + e3607, + e822, + e1990, + e496, + e1698, + e2362, + e3670, + ]); + assert!(p.eval_at(&s) == MclFr::from(30830413)); + + } +} diff --git a/src/building_block/mcl/qap/config.rs b/src/building_block/mcl/qap/config.rs new file mode 100644 index 0000000..4f2bfd7 --- /dev/null +++ b/src/building_block/mcl/qap/config.rs @@ -0,0 +1 @@ +pub type SignalId = u128; diff --git a/src/building_block/mcl/qap/constraint.rs b/src/building_block/mcl/qap/constraint.rs new file mode 100644 index 0000000..708ff55 --- /dev/null +++ b/src/building_block/mcl/qap/constraint.rs @@ -0,0 +1,28 @@ +use crate::building_block::mcl::mcl_sparse_vec::MclSparseVec; + +#[derive(Clone)] +pub struct Constraint { + pub a: MclSparseVec, + pub b: MclSparseVec, + pub c: MclSparseVec, +} + +impl Constraint { + pub fn new( + a: &MclSparseVec, + b: &MclSparseVec, + c: &MclSparseVec, + ) -> Self { + let a = a.clone(); + let b = b.clone(); + let c = c.clone(); + Constraint { a, b, c } + } +} + +impl std::fmt::Debug for Constraint { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{:?} . w * {:?} . w - {:?} . w = 0", self.a, self.b, self.c) + } +} + diff --git a/src/building_block/mcl/qap/equation_parser.rs b/src/building_block/mcl/qap/equation_parser.rs new file mode 100644 index 0000000..ea54330 --- /dev/null +++ b/src/building_block/mcl/qap/equation_parser.rs @@ -0,0 +1,686 @@ +use nom::{ + IResult, + branch::alt, + bytes::complete::tag, + character::complete::{ alpha1, char, multispace0, one_of }, + combinator::{ opt, recognize }, + multi::{ many0, many1 }, + sequence::{ tuple, delimited, terminated }, +}; +use crate::building_block::mcl::mcl_fr::MclFr; +use crate::zk::w_trusted_setup::qap::config::SignalId; +use std::cell::Cell; + +#[derive(Debug, PartialEq, Clone)] +pub enum MathExpr { + Equation(Box, Box), + Num(MclFr), + Var(String), + Mul(SignalId, Box, Box), + Add(SignalId, Box, Box), + Div(SignalId, Box, Box), + Sub(SignalId, Box, Box), +} + +#[derive(Debug)] +pub struct Equation { + pub lhs: MathExpr, + pub rhs: MclFr, +} + +pub struct EquationParser(); + +macro_rules! set_next_id { + ($signal_id: expr) => { + $signal_id.set($signal_id.get() + 1); + }; +} + +impl EquationParser { + fn num_str_to_field_elem(s: &str) -> MclFr { + MclFr::from(s) + } + + fn variable<'a>() -> impl Fn(&str) -> IResult<&str, MathExpr> + 'a { + |input| { + let (input, s) = + delimited( + multispace0, + recognize( + terminated(alpha1, many0(one_of("0123456789"))) + ), + multispace0 + )(input)?; + + Ok((input, MathExpr::Var(s.to_string()))) + } + } + + fn decimal() -> impl Fn(&str) -> IResult<&str, MathExpr> { + |input| { + let (input, s) = + delimited( + multispace0, + recognize( + tuple(( + opt(char('-')), + many1( + one_of("0123456789") + ), + )), + ), + multispace0 + )(input)?; + + let n = EquationParser::num_str_to_field_elem(s); + Ok((input, MathExpr::Num(n))) + } + } + + // ::= | | '(' ')' + fn term2<'a>(signal_id: &'a Cell) -> impl Fn(&str) -> IResult<&str, MathExpr> + 'a { + |input| { + let (input, node) = alt(( + EquationParser::variable(), + EquationParser::decimal(), + delimited( + delimited(multispace0, char('('), multispace0), + EquationParser::expr(signal_id), + delimited(multispace0, char(')'), multispace0), + ), + ))(input)?; + + Ok((input, node)) + } + } + + // ::= [ ('*'|'/') ]* + fn term1<'a>(signal_id: &'a Cell) -> impl Fn(&str) -> IResult<&str, MathExpr> + 'a { + |input| { + let rhs = tuple((alt((char('*'), char('/'))), EquationParser::term2(signal_id))); + let (input, (lhs, rhs)) = tuple(( + EquationParser::term2(signal_id), + many0(rhs), + ))(input)?; + + if rhs.len() == 0 { + Ok((input, lhs)) + } else { + // translate rhs vector to Mul>>.. + let rhs_head = &rhs[0]; + let rhs = rhs.iter().skip(1).fold(rhs_head.1.clone(), |acc, x| { + match x { + ('*', node) => { + set_next_id!(signal_id); + MathExpr::Mul(signal_id.get(), Box::new(acc), Box::new(node.clone())) + }, + ('/', node) => { + set_next_id!(signal_id); + MathExpr::Div(signal_id.get(), Box::new(acc), Box::new(node.clone())) + }, + (op, _) => panic!("unexpected operator encountered in term1 {}", op), + } + }); + + set_next_id!(signal_id); + let node = if rhs_head.0 == '*' { + MathExpr::Mul(signal_id.get(), Box::new(lhs), Box::new(rhs)) + } else { + MathExpr::Div(signal_id.get(), Box::new(lhs), Box::new(rhs)) + }; + Ok((input, node)) + } + } + } + + // ::= [ ('+'|'-') ]* + fn expr<'a>(signal_id: &'a Cell) -> impl Fn(&str) -> IResult<&str, MathExpr> + 'a { + |input| { + let rhs = tuple((alt((char('+'), char('-'))), EquationParser::term1(signal_id))); + let (input, (lhs, rhs)) = tuple(( + EquationParser::term1(signal_id), + many0(rhs), + ))(input)?; + + if rhs.len() == 0 { + Ok((input, lhs)) + } else { + // translate rhs vector to Add>>.. + let rhs_head = &rhs[0]; + let rhs = rhs.iter().skip(1).fold(rhs_head.1.clone(), |acc, x| { + match x { + ('+', node) => { + set_next_id!(signal_id); + MathExpr::Add(signal_id.get(), Box::new(acc), Box::new(node.clone())) + }, + ('-', node) => { + set_next_id!(signal_id); + MathExpr::Sub(signal_id.get(), Box::new(acc), Box::new(node.clone())) + }, + (op, _) => panic!("unexpected operator encountered in expr: {}", op), + } + }); + + set_next_id!(signal_id); + let node = if rhs_head.0 == '+' { + MathExpr::Add(signal_id.get(), Box::new(lhs), Box::new(rhs)) + } else { + MathExpr::Sub(signal_id.get(), Box::new(lhs), Box::new(rhs)) + }; + Ok((input, node)) + } + } + } + + // ::= '=' + fn equation<'a>(signal_id: &'a Cell) -> impl Fn(&str) -> IResult<&str, MathExpr> + 'a { + |input| { + let (input, out) = + tuple(( + multispace0, + EquationParser::expr(signal_id), + multispace0, + tag("=="), + multispace0, + EquationParser::decimal(), + multispace0, + ))(input)?; + + let lhs = out.1; + let rhs = out.5; + Ok((input, MathExpr::Equation(Box::new(lhs), Box::new(rhs)))) + } + } + // ::= [ ('*'|'/') ]* + // ::= | | '(' ')' + // ::= [ ('+'|'-') ]* + // ::= '==' + pub fn parse<'a>(input: &'a str) -> Result { + let signal_id = Cell::new(0); + let expr = EquationParser::equation(&signal_id); + match expr(input) { + Ok((_, expr)) => { + match expr { + MathExpr::Equation(lhs, rhs) => { + if let MathExpr::Num(n) = *rhs { + Ok(Equation { lhs: *lhs, rhs: n }) + } else { + Err(format!("Equation has unexpected RHS: {:?}", rhs)) + } + }, + _ => Err(format!("Unexpected parse result: {:?}", expr)) + } + + }, + Err(x) => Err(x.to_string()), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::building_block::mcl::mcl_initializer::MclInitializer; + + #[test] + fn test_decimal() { + MclInitializer::init(); + match EquationParser::parse("123 == 123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Num(MclFr::from(123))); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_decimal_with_spaces() { + MclInitializer::init(); + match EquationParser::parse(" 123 == 123 ") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Num(MclFr::from(123))); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_neg_decimal_below_order() { + MclInitializer::init(); + match EquationParser::parse("-123 == -123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Num(MclFr::from(-123))); + assert_eq!(eq.rhs, MclFr::from(-123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_neg_decimal_above_order() { + MclInitializer::init(); + match EquationParser::parse("-123 == -123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Num(MclFr::from(-123))); + assert_eq!(eq.rhs, MclFr::from(-123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_1_char_variable() { + MclInitializer::init(); + for s in vec!["x", "x1", "x0", "xy", "xy1"] { + match EquationParser::parse(&format!("{} == 123", s)) { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Var(s.to_string())); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + } + + #[test] + fn test_1_char_variable_with_spaces() { + MclInitializer::init(); + for s in vec!["x", "x1", "x0", "xy", "xy1"] { + match EquationParser::parse(&format!(" {} == 123 ", s)) { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Var(s.to_string())); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + } + + #[test] + fn test_simple_add_expr() { + MclInitializer::init(); + match EquationParser::parse("123+456==123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Add(1, + Box::new(MathExpr::Num(MclFr::from(123))), + Box::new(MathExpr::Num(MclFr::from(456))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_simple_add_expr_with_1_var() { + MclInitializer::init(); + for s in vec!["x", "x1", "x0", "xy", "xy1"] { + match EquationParser::parse(&format!("{}+456==123", s)) { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Add(1, + Box::new(MathExpr::Var(s.to_string())), + Box::new(MathExpr::Num(MclFr::from(456))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + } + + #[test] + fn test_simple_add_expr_with_2_vars() { + MclInitializer::init(); + for (a,b) in vec![("x", "y"), ("x1", "y1"), ("xxx1123", "yyy123443")] { + match EquationParser::parse(&format!("{}+{}==123", a, b)) { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Add(1, + Box::new(MathExpr::Var(a.to_string())), + Box::new(MathExpr::Var(b.to_string())), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + } + + #[test] + fn test_simple_add_expr_incl_neg() { + MclInitializer::init(); + match EquationParser::parse("123+-456==123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Add(1, + Box::new(MathExpr::Num(MclFr::from(123))), + Box::new(MathExpr::Num(MclFr::from(-456))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_simple_sub_expr() { + MclInitializer::init(); + match EquationParser::parse("123-456==123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Sub(1, + Box::new(MathExpr::Num(MclFr::from(123))), + Box::new(MathExpr::Num(MclFr::from(456))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_simple_sub_expr_1_var() { + MclInitializer::init(); + for s in vec!["x", "x1", "x0", "xy", "xy1"] { + match EquationParser::parse(&format!("123-{}==123", s)) { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Sub(1, + Box::new(MathExpr::Num(MclFr::from(123))), + Box::new(MathExpr::Var(s.to_string())), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + } + + #[test] + fn test_simple_sub_expr_incl_neg1() { + MclInitializer::init(); + match EquationParser::parse("-123-456==123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Sub(1, + Box::new(MathExpr::Num(MclFr::from(-123))), + Box::new(MathExpr::Num(MclFr::from(456))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_simple_sub_expr_incl_neg1_1_var() { + MclInitializer::init(); + for s in vec!["x", "x1", "x0", "xy", "xy1"] { + match EquationParser::parse(&format!("-123-{}==123", s)) { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Sub(1, + Box::new(MathExpr::Num(MclFr::from(-123))), + Box::new(MathExpr::Var(s.to_string())), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + } + + #[test] + fn test_simple_sub_expr_incl_neg2() { + MclInitializer::init(); + match EquationParser::parse("123--456==123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Sub(1, + Box::new(MathExpr::Num(MclFr::from(123))), + Box::new(MathExpr::Num(MclFr::from(-456))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_simple_sub_expr_incl_neg2_with_spaces() { + MclInitializer::init(); + match EquationParser::parse("123 - -456 == 123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Sub(1, + Box::new(MathExpr::Num(MclFr::from(123))), + Box::new(MathExpr::Num(MclFr::from(-456))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_simple_sub_expr_incl_neg2_with_spaces_1_var() { + MclInitializer::init(); + match EquationParser::parse("x - -456 == 123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Sub(1, + Box::new(MathExpr::Var("x".to_string())), + Box::new(MathExpr::Num(MclFr::from(-456))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_simple_mul_expr() { + MclInitializer::init(); + match EquationParser::parse("123*456==123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Mul(1, + Box::new(MathExpr::Num(MclFr::from(123))), + Box::new(MathExpr::Num(MclFr::from(456))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_simple_mul_expr_incl_neg1() { + MclInitializer::init(); + match EquationParser::parse("123*-456==123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Mul(1, + Box::new(MathExpr::Num(MclFr::from(123))), + Box::new(MathExpr::Num(MclFr::from(-456))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_simple_mul_expr_incl_neg2() { + MclInitializer::init(); + match EquationParser::parse("-123*456==123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Mul(1, + Box::new(MathExpr::Num(MclFr::from(-123))), + Box::new(MathExpr::Num(MclFr::from(456))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_simple_mul_expr_incl_neg() { + MclInitializer::init(); + match EquationParser::parse("123*-456==123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Mul(1, + Box::new(MathExpr::Num(MclFr::from(123))), + Box::new(MathExpr::Num(MclFr::from(-456))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_simple_div_expr() { + MclInitializer::init(); + match EquationParser::parse("123/456==123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Div(1, + Box::new(MathExpr::Num(MclFr::from(123))), + Box::new(MathExpr::Num(MclFr::from(456))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_add_and_mul_expr() { + MclInitializer::init(); + match EquationParser::parse("123+456*789==123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Add(2, + Box::new(MathExpr::Num(MclFr::from(123))), + Box::new(MathExpr::Mul(1, + Box::new(MathExpr::Num(MclFr::from(456))), + Box::new(MathExpr::Num(MclFr::from(789))), + )), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_add_mul_div_expr() { + MclInitializer::init(); + match EquationParser::parse("111/222+333*444==123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Add(3, + Box::new(MathExpr::Div(1, + Box::new(MathExpr::Num(MclFr::from(111))), + Box::new(MathExpr::Num(MclFr::from(222))), + )), + Box::new(MathExpr::Mul(2, + Box::new(MathExpr::Num(MclFr::from(333))), + Box::new(MathExpr::Num(MclFr::from(444))), + )), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_paren_add_and_mul_expr() { + MclInitializer::init(); + match EquationParser::parse("(123+456)*789==123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Mul(2, + Box::new(MathExpr::Add(1, + Box::new(MathExpr::Num(MclFr::from(123))), + Box::new(MathExpr::Num(MclFr::from(456))), + )), + Box::new(MathExpr::Num(MclFr::from(789))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_paren_add_and_mul_expr_with_spaces() { + MclInitializer::init(); + match EquationParser::parse(" (123 + 456) * 789 == 123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Mul(2, + Box::new(MathExpr::Add(1, + Box::new(MathExpr::Num(MclFr::from(123))), + Box::new(MathExpr::Num(MclFr::from(456))), + )), + Box::new(MathExpr::Num(MclFr::from(789))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_paren_add_mul_sub_expr() { + MclInitializer::init(); + match EquationParser::parse("(111+222)*(333-444)==123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Mul(3, + Box::new(MathExpr::Add(1, + Box::new(MathExpr::Num(MclFr::from(111))), + Box::new(MathExpr::Num(MclFr::from(222))), + )), + Box::new(MathExpr::Sub(2, + Box::new(MathExpr::Num(MclFr::from(333))), + Box::new(MathExpr::Num(MclFr::from(444))), + )), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_multiple_paren() { + MclInitializer::init(); + match EquationParser::parse("((111+222))==123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Add(1, + Box::new(MathExpr::Num(MclFr::from(111))), + Box::new(MathExpr::Num(MclFr::from(222))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn test_multiple_paren_with_spaces() { + MclInitializer::init(); + match EquationParser::parse(" ( (111+222) ) == 123") { + Ok(eq) => { + assert_eq!(eq.lhs, MathExpr::Add(1, + Box::new(MathExpr::Num(MclFr::from(111))), + Box::new(MathExpr::Num(MclFr::from(222))), + )); + assert_eq!(eq.rhs, MclFr::from(123)); + }, + Err(_) => panic!(), + } + } + + #[test] + fn blog_post_1_example_1() { + MclInitializer::init(); + let expr = "(x * x * x) + x + 5 == 35"; + match EquationParser::parse(expr) { + Ok(eq) => { + println!("{} -> {:?}", expr, eq); + }, + Err(_) => panic!(), + } + } +} diff --git a/src/building_block/mcl/qap/gate.rs b/src/building_block/mcl/qap/gate.rs new file mode 100644 index 0000000..db610ad --- /dev/null +++ b/src/building_block/mcl/qap/gate.rs @@ -0,0 +1,251 @@ +use crate::building_block::mcl::qap::{ + equation_parser::{ + Equation, + MathExpr, + }, + term::Term, +}; + +pub struct Gate { + pub a: Term, + pub b: Term, + pub c: Term, +} + +impl std::fmt::Debug for Gate { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{:?} * {:?} = {:?}", self.a, self.b, self.c) + } +} + +impl Gate { + // traverse the Equation tree generating statement at each Add/Mul node + fn traverse_lhs( + expr: &MathExpr, gates: &mut Vec + ) -> Term { + match expr { + MathExpr::Num(n) => Term::Num(n.clone()), + MathExpr::Var(s) => Term::Var(s.clone()), + + MathExpr::Add(signal_id, left, right) => { + let a = Gate::traverse_lhs(left, gates); + let b = Gate::traverse_lhs(right, gates); + let c = Term::TmpVar(*signal_id); + // a + b = c + // -> (a + b) * 1 = c + let sum = Term::Sum(Box::new(a), Box::new(b)); + gates.push(Gate { a: sum, b: Term::One, c: c.clone() }); + c + }, + MathExpr::Mul(signal_id, left, right) => { + let a = Gate::traverse_lhs(left, gates); + let b = Gate::traverse_lhs(right, gates); + let c = Term::TmpVar(*signal_id); + gates.push(Gate { a, b, c: c.clone() }); + c + }, + MathExpr::Sub(signal_id, left, right) => { + let a = Gate::traverse_lhs(left, gates); + let b = Gate::traverse_lhs(right, gates); + let c = Term::TmpVar(*signal_id); + // a - b = c + // -> b + c = a + // -> (b + c) * 1 = a + let sum = Term::Sum(Box::new(b), Box::new(c.clone())); + gates.push(Gate { a: sum, b: Term::One, c: a.clone() }); + c + }, + MathExpr::Div(signal_id, left, right) => { + let a = Gate::traverse_lhs(left, gates); + let b = Gate::traverse_lhs(right, gates); + let c = Term::TmpVar(*signal_id); + // a / b = c + // -> b * c = a + gates.push(Gate { a: b, b: c.clone(), c: a.clone() }); + // send c to next as the original division does + c + }, + MathExpr::Equation(_lhs, _rhs) => { + panic!("should not be visited"); + } + } + } + + pub fn build(eq: &Equation) -> Vec { + let mut gates: Vec = vec![]; + let root = Gate::traverse_lhs(&eq.lhs, &mut gates); + let out_gate = Gate { a: root, b: Term::One, c: Term::Out }; + gates.push(out_gate); + gates + } +} + + +#[cfg(test)] +mod tests { + use super::*; + use crate::building_block::mcl::{ + mcl_fr::MclFr, + mcl_initializer::MclInitializer, + qap::equation_parser::EquationParser, + }; + + #[test] + fn test_build_add() { + MclInitializer::init(); + let input = "x + 4 == 9"; + let eq = EquationParser::parse(input).unwrap(); + let gates = &Gate::build(&eq); + println!("{:?}", gates); + assert_eq!(gates.len(), 2); + + // t1 = (x + 4) * 1 + assert_eq!(gates[0].a, Term::Sum(Box::new(Term::Var("x".to_string())), Box::new(Term::Num(MclFr::from(4))))); + assert_eq!(gates[0].b, Term::One); + assert_eq!(gates[0].c, Term::TmpVar(1)); + + // out = t1 * 1 + assert_eq!(gates[1].a, Term::TmpVar(1)); + assert_eq!(gates[1].b, Term::One); + assert_eq!(gates[1].c, Term::Out); + } + + #[test] + fn test_build_sub() { + MclInitializer::init(); + let input = "x - 4 == 9"; + let eq = EquationParser::parse(input).unwrap(); + let gates = &Gate::build(&eq); + assert_eq!(gates.len(), 2); + + // t1 = (x + 4) * 1 + assert_eq!(gates[0].a, Term::Sum(Box::new(Term::Num(MclFr::from(4))), Box::new(Term::TmpVar(1)))); + assert_eq!(gates[0].b, Term::One); + assert_eq!(gates[0].c, Term::Var("x".to_string())); + + // out = t1 * 1 + assert_eq!(gates[1].a, Term::TmpVar(1)); + assert_eq!(gates[1].b, Term::One); + assert_eq!(gates[1].c, Term::Out); + } + + #[test] + fn test_build_mul() { + MclInitializer::init(); + let input = "x * 4 == 9"; + let eq = EquationParser::parse(input).unwrap(); + let gates = &Gate::build(&eq); + assert_eq!(gates.len(), 2); + + // x = (4 + t1) * 1 + assert_eq!(gates[0].a, Term::Var("x".to_string())); + assert_eq!(gates[0].b, Term::Num(MclFr::from(4))); + assert_eq!(gates[0].c, Term::TmpVar(1)); + + // out = t1 * 1 + assert_eq!(gates[1].a, Term::TmpVar(1)); + assert_eq!(gates[1].c, Term::Out); + } + + #[test] + fn test_build_div() { + MclInitializer::init(); + let input = "x / 4 == 2"; + let eq = EquationParser::parse(input).unwrap(); + let gates = &Gate::build(&eq); + assert_eq!(gates.len(), 2); + + // x = 4 * t1 + assert_eq!(gates[0].a, Term::Num(MclFr::from(4))); + assert_eq!(gates[0].b, Term::TmpVar(1)); + assert_eq!(gates[0].c, Term::Var("x".to_string())); + + // out = t1 * 1 + assert_eq!(gates[1].a, Term::TmpVar(1)); + assert_eq!(gates[1].c, Term::Out); + } + + #[test] + fn test_build_combined1() { + MclInitializer::init(); + let input = "(3 * x + 4) / 2 == 11"; + let eq = EquationParser::parse(input).unwrap(); + let gates = &Gate::build(&eq); + assert_eq!(gates.len(), 4); + + // t1 = 3 * x + assert_eq!(gates[0].a, Term::Num(MclFr::from(3))); + assert_eq!(gates[0].b, Term::Var("x".to_string())); + assert_eq!(gates[0].c, Term::TmpVar(1)); + + // t2 = (t1 + 4) * 1 + assert_eq!(gates[1].a, Term::Sum( + Box::new(Term::TmpVar(1)), + Box::new(Term::Num(MclFr::from(4))) + )); + assert_eq!(gates[1].b, Term::One); + assert_eq!(gates[1].c, Term::TmpVar(2)); + + // t2 = 2 * t3 + assert_eq!(gates[2].a, Term::Num(MclFr::from(2))); + assert_eq!(gates[2].b, Term::TmpVar(3)); + assert_eq!(gates[2].c, Term::TmpVar(2)); + + // out = t3 * 1 + assert_eq!(gates[3].a, Term::TmpVar(3)); + assert_eq!(gates[3].b, Term::One); + assert_eq!(gates[3].c, Term::Out); + } + + #[test] + fn test_build_combined2() { + MclInitializer::init(); + let input = "(x * x * x) + x + 5 == 35"; + println!("Equation: {}", input); + + let eq = EquationParser::parse(input).unwrap(); + let gates = &Gate::build(&eq); + println!("Gates: {:?}", gates); + assert_eq!(gates.len(), 5); + + // t1 = x * x + assert_eq!(gates[0].a, Term::Var("x".to_string())); + assert_eq!(gates[0].b, Term::Var("x".to_string())); + assert_eq!(gates[0].c, Term::TmpVar(1)); + + // t2 = x * t1 + assert_eq!(gates[1].a, Term::Var("x".to_string())); + assert_eq!(gates[1].b, Term::TmpVar(1)); + assert_eq!(gates[1].c, Term::TmpVar(2)); + + // t3 = (x + 5) * 1 + assert_eq!(gates[2].a, Term::Sum( + Box::new(Term::Var("x".to_string())), + Box::new(Term::Num(MclFr::from(5))) + )); + assert_eq!(gates[2].b, Term::One); + assert_eq!(gates[2].c, Term::TmpVar(3)); + + // t4 = (t2 + t3) * 1 + assert_eq!(gates[3].a, Term::Sum( + Box::new(Term::TmpVar(2)), + Box::new(Term::TmpVar(3)) + )); + assert_eq!(gates[3].b, Term::One); + assert_eq!(gates[3].c, Term::TmpVar(4)); + + // out = t4 * 1 + assert_eq!(gates[4].a, Term::TmpVar(4)); + assert_eq!(gates[4].b, Term::One); + assert_eq!(gates[4].c, Term::Out); + } + + #[test] + fn blog_post_1_example_1() { + MclInitializer::init(); + let expr = "(x * x * x) + x + 5 == 35"; + let eq = EquationParser::parse(expr).unwrap(); + let gates = &Gate::build(&eq); + println!("{:?}", gates); + } +} diff --git a/src/building_block/mcl/qap/gates/adder.rs b/src/building_block/mcl/qap/gates/adder.rs new file mode 100644 index 0000000..788c7ac --- /dev/null +++ b/src/building_block/mcl/qap/gates/adder.rs @@ -0,0 +1,131 @@ +use crate::zk::w_trusted_setup::qap::gates::bool_circuit::{BoolCircuit, Processor}; + +pub struct HalfAdder(); + +#[derive(Debug)] +pub struct AdderResult { + pub sum: bool, + pub carry: bool, +} + +impl HalfAdder { + // (augend, addend) -> (sum, carry) + pub fn add(augend: bool, addend: bool) -> AdderResult { + let sum = BoolCircuit::Xor( + Box::new(BoolCircuit::Leaf(augend)), + Box::new(BoolCircuit::Leaf(addend)), + ); + let carry = BoolCircuit::And( + Box::new(BoolCircuit::Leaf(augend)), + Box::new(BoolCircuit::Leaf(addend)), + ); + + let sum = Processor::eval(&sum); + let carry = Processor::eval(&carry); + + AdderResult { sum, carry } + } +} + +pub struct FullAdder(); + +impl FullAdder { + pub fn add(augend: bool, addend: bool, carry: bool) -> AdderResult { + let res1 = HalfAdder::add(augend, addend); + let res2 = HalfAdder::add(res1.sum, carry); + let carry = BoolCircuit::Or( + Box::new(BoolCircuit::Leaf(res1.carry)), + Box::new(BoolCircuit::Leaf(res2.carry)), + ); + let carry = Processor::eval(&carry); + AdderResult { sum: res2.sum, carry } + } +} + +#[cfg(test)] +mod half_adder_tests { + use super::*; + use crate::building_block::mcl::mcl_initializer::MclInitializer; + + #[test] + fn add_0_0() { + MclInitializer::init(); + let res = HalfAdder::add(false, false); + assert_eq!(res.sum, false); + assert_eq!(res.carry, false); + } + + #[test] + fn add_1_0_or_0_1() { + MclInitializer::init(); + let res = HalfAdder::add(true, false); + assert_eq!(res.sum, true); + assert_eq!(res.carry, false); + + let res = HalfAdder::add(false, true); + assert_eq!(res.sum, true); + assert_eq!(res.carry, false); + } + + #[test] + fn add_1_1() { + MclInitializer::init(); + let res = HalfAdder::add(true, true); + assert_eq!(res.sum, false); + assert_eq!(res.carry, true); + } +} + +#[cfg(test)] +mod full_adder_tests { + use super::*; + use crate::building_block::mcl::mcl_initializer::MclInitializer; + + #[test] + fn single_inst_add_0_0_0() { + MclInitializer::init(); + let res = FullAdder::add(false, false, false); + assert_eq!(res.sum, false); + assert_eq!(res.carry, false); + } + + #[test] + fn single_inst_add_1_0_0_or_0_1_0_or_0_0_1() { + MclInitializer::init(); + let res = FullAdder::add(true, false, false); + assert_eq!(res.sum, true); + assert_eq!(res.carry, false); + + let res = FullAdder::add(false, true, false); + assert_eq!(res.sum, true); + assert_eq!(res.carry, false); + + let res = FullAdder::add(false, false, true); + assert_eq!(res.sum, true); + assert_eq!(res.carry, false); + } + + #[test] + fn single_inst_add_1_1_0_or_1_0_1_or_0_1_1() { + MclInitializer::init(); + let res = FullAdder::add(true, true, false); + assert_eq!(res.sum, false); + assert_eq!(res.carry, true); + + let res = FullAdder::add(true, false, true); + assert_eq!(res.sum, false); + assert_eq!(res.carry, true); + + let res = FullAdder::add(false, true, true); + assert_eq!(res.sum, false); + assert_eq!(res.carry, true); + } + + #[test] + fn single_inst_add_1_1_1() { + MclInitializer::init(); + let res = FullAdder::add(true, true, true); + assert_eq!(res.sum, true); + assert_eq!(res.carry, true); + } +} diff --git a/src/building_block/mcl/qap/gates/arith_circuit.rs b/src/building_block/mcl/qap/gates/arith_circuit.rs new file mode 100644 index 0000000..acd159d --- /dev/null +++ b/src/building_block/mcl/qap/gates/arith_circuit.rs @@ -0,0 +1,12 @@ +use crate::building_block::mcl::mcl_fr::MclFr; + +#[derive(Debug, PartialEq, Clone)] +pub enum ArithCircuit { + Leaf(MclFr), + Mul(Box, Box), + Add(Box, Box), + Sub(Box, Box), + Div(Box, Box), +} + +pub struct Processor(); diff --git a/src/building_block/mcl/qap/gates/bool_circuit.rs b/src/building_block/mcl/qap/gates/bool_circuit.rs new file mode 100644 index 0000000..58dc07e --- /dev/null +++ b/src/building_block/mcl/qap/gates/bool_circuit.rs @@ -0,0 +1,76 @@ +use crate::building_block::mcl::{ + mcl_fr::MclFr, + qap::gates::arith_circuit::ArithCircuit, +}; + +#[derive(Debug, PartialEq, Clone)] +pub enum BoolCircuit { + Leaf(bool), + And(Box, Box), + Xor(Box, Box), + Or(Box, Box), +} + +pub struct Processor(); + +impl Processor { + pub fn eval(root: &BoolCircuit) -> bool { + match root { + BoolCircuit::Leaf(x) => *x, + BoolCircuit::And(a, b) => Processor::eval(&a) && Processor::eval(&b), + BoolCircuit::Xor(a, b) => { + let a = Processor::eval(&a); + let b = Processor::eval(&b); + !(a && b) && (a || b) + } + BoolCircuit::Or(a, b) => Processor::eval(&a) || Processor::eval(&b), + } + } + + pub fn to_arith_circuit(root: BoolCircuit) -> ArithCircuit { + match root { + BoolCircuit::Leaf(x) => ArithCircuit::Leaf(MclFr::from(x)), + BoolCircuit::And(a, b) => { + let a = Processor::eval(&a); + let b = Processor::eval(&b); + let a = ArithCircuit::Leaf(MclFr::from(a)); + let b = ArithCircuit::Leaf(MclFr::from(b)); + // AND(a, b) = ab + ArithCircuit::Mul(Box::new(a), Box::new(b)) + }, + BoolCircuit::Xor(a, b) => { + let a = Processor::eval(&a); + let b = Processor::eval(&b); + let a = ArithCircuit::Leaf(MclFr::from(a)); + let b = ArithCircuit::Leaf(MclFr::from(b)); + + // XOR(a, b) = (a + b) - 2 ab + let t1 = ArithCircuit::Add( + Box::new(a.clone()), + Box::new(b.clone()), + ); + + let two = ArithCircuit::Leaf(MclFr::from(2)); + let t2 = ArithCircuit::Mul(Box::new(a), Box::new(b)); + let t2 = ArithCircuit::Mul(Box::new(two), Box::new(t2)); + ArithCircuit::Add( + Box::new(t1), + Box::new(t2), + ) + }, + BoolCircuit::Or(a, b) => { + let a = Processor::eval(&a); + let b = Processor::eval(&b); + let a = ArithCircuit::Leaf(MclFr::from(a)); + let b = ArithCircuit::Leaf(MclFr::from(b)); + // Or(a, b) = a + b - a * b + let t1 = ArithCircuit::Add(Box::new(a.clone()), Box::new(b.clone())); + let t2 = ArithCircuit::Mul(Box::new(a.clone()), Box::new(b.clone())); + ArithCircuit::Sub( + Box::new(t1), + Box::new(t2), + ) + }, + } + } +} diff --git a/src/building_block/mcl/qap/gates/mod.rs b/src/building_block/mcl/qap/gates/mod.rs new file mode 100644 index 0000000..0d0b209 --- /dev/null +++ b/src/building_block/mcl/qap/gates/mod.rs @@ -0,0 +1,4 @@ +pub mod arith_circuit; +pub mod bool_circuit; +pub mod number; +pub mod adder; \ No newline at end of file diff --git a/src/building_block/mcl/qap/gates/number.rs b/src/building_block/mcl/qap/gates/number.rs new file mode 100644 index 0000000..2329e31 --- /dev/null +++ b/src/building_block/mcl/qap/gates/number.rs @@ -0,0 +1,119 @@ +#[derive(Debug)] +pub struct Number { + pub bits: [bool; 64], +} + +impl Number { + pub fn new(n: i64) -> Self { + let mut bits = [false; 64]; + + if n == 0 { + Number { bits } + + } else { + let mut m = n; + if n < 0 { + // convert to a positive number w/ the same bit representation + m = i64::MAX + n + 1; + } + + let mut x = m; + let mut i = 0; + while x > 0 { + if x & 1 == 1 { + bits[i] = true; + } + i += 1; + x >>= 1; + } + + if n < 0 { // set sign bit if originally a negative value + bits[63] = true; + } + + Number { bits } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::building_block::mcl::mcl_initializer::MclInitializer; + + #[test] + fn zero() { + MclInitializer::init(); + let x = Number::new(0); + for i in 0..64 { + assert_eq!(x.bits[i], false); + } + } + + #[test] + fn pos_1() { + MclInitializer::init(); + let x = Number::new(1); + for i in 0..64 { + let exp = i == 0; + assert_eq!(x.bits[i], exp); + } + } + + #[test] + fn pos_5() { + MclInitializer::init(); + let x = Number::new(5); + assert_eq!(x.bits[0], true); + assert_eq!(x.bits[1], false); + assert_eq!(x.bits[2], true); + for i in 3..64 { + assert_eq!(x.bits[i], false); + } + } + + #[test] + fn pos_7() { + MclInitializer::init(); + let x = Number::new(7); + assert_eq!(x.bits[0], true); + assert_eq!(x.bits[1], true); + assert_eq!(x.bits[2], true); + for i in 3..64 { + assert_eq!(x.bits[i], false); + } + } + + #[test] + fn neg_1() { + MclInitializer::init(); + let x = Number::new(-1); + for i in 0..64 { + assert_eq!(x.bits[i], true); + } + } + + #[test] + fn neg_5() { + MclInitializer::init(); + let x = Number::new(-5); + assert_eq!(x.bits[0], true); + assert_eq!(x.bits[1], true); + assert_eq!(x.bits[2], false); + for i in 3..64 { + assert_eq!(x.bits[i], true); + } + } + + #[test] + fn neg_7() { + MclInitializer::init(); + let x = Number::new(-7); + assert_eq!(x.bits[0], true); + assert_eq!(x.bits[1], false); + assert_eq!(x.bits[2], false); + for i in 3..64 { + assert_eq!(x.bits[i], true); + } + } +} diff --git a/src/building_block/mcl/qap/mod.rs b/src/building_block/mcl/qap/mod.rs new file mode 100644 index 0000000..bd645ea --- /dev/null +++ b/src/building_block/mcl/qap/mod.rs @@ -0,0 +1,9 @@ +pub mod config; +pub mod constraint; +pub mod equation_parser; +pub mod gate; +pub mod gates; +pub mod qap; +pub mod r1cs; +pub mod r1cs_tmpl; +pub mod term; diff --git a/src/building_block/mcl/qap/qap.rs b/src/building_block/mcl/qap/qap.rs new file mode 100644 index 0000000..f7b3525 --- /dev/null +++ b/src/building_block/mcl/qap/qap.rs @@ -0,0 +1,326 @@ +use std::ops::Mul; + +use crate::building_block::mcl::{ + mcl_fr::MclFr, + mcl_sparse_vec::MclSparseVec, + polynomial::{ + DivResult, + Polynomial, + }, + qap::r1cs::R1CS, +}; +use num_traits::Zero; + +#[derive(Clone)] +pub struct QAP { + pub vi: Vec, + pub wi: Vec, + pub yi: Vec, + pub num_constraints: MclFr, +} + +impl QAP { + // build a polynomial that evaluates to target_val at x == index + // and zero for x != index for each target value. + // e.g. + // (x - 2) * (x - 3) * 3 / ((1 - 2) * (1 - 3)) + // where x in [1, 2, 3]; evaluates to 3 if x == 1 and 0 if x != 1 + fn build_polynomial(target_vals: &MclSparseVec) -> Polynomial { + let mut target_val_polys = vec![]; + + let one = MclFr::from(1); + let mut target_x = MclFr::from(1); + while target_x <= target_vals.size { + let target_val = target_vals.get(&(&target_x - &one)); + + // if target val is zero, simply add 0x^0 + if target_val.is_zero() { + target_val_polys.push(Polynomial::new(&vec![MclFr::from(0)])); + target_x.inc(); + continue; + } + + let mut numerator_polys = vec![ + Polynomial::new(&vec![target_val.clone()]), + ]; + let mut denominator = MclFr::from(1); + + let mut i = MclFr::from(1); + while &i <= &target_vals.size { + if &i == &target_x { + i.inc(); + continue; + } + // (x - i) to let the polynomal evaluate to zero at x = i + let numerator_poly = Polynomial::new(&vec![ + -&i, + MclFr::from(1), + ]); + numerator_polys.push(numerator_poly); + + // (target_idx - i) to cancel out the corresponding + // numerator_poly at x = target_idx + denominator = &denominator * (&target_x - &i); + + i.inc(); + } + + // merge denominator polynomial to numerator polynomial vector + let denominator_poly = Polynomial::new(&vec![denominator.inv()]); + let mut polys = numerator_polys; + polys.push(denominator_poly); + + // aggregate numerator polynomial vector + let mut acc_poly = Polynomial::new(&vec![MclFr::from(1)]); + for poly in polys { + acc_poly = acc_poly.mul(&poly); + } + target_val_polys.push(acc_poly); + + target_x.inc(); + } + + // aggregate polynomials for all target values + let mut res = target_val_polys[0].clone(); + for x in &target_val_polys[1..] { + res += x; + } + res + } + + pub fn build_p(&self, witness: &MclSparseVec) -> Polynomial { + let zero = &Polynomial::zero(); + let (mut v, mut w, mut y) = + (zero.clone(), zero.clone(), zero.clone()); + + for i in 0..witness.size.to_usize() { + let wit = &witness[&MclFr::from(i)]; + v += &self.vi[i] * wit; + w += &self.wi[i] * wit; + y += &self.yi[i] * wit; + }; + + (v * &w) - &y + } + + // build polynomial (x-1)(x-2)..(x-num_constraints) + pub fn build_t(num_constraints: &MclFr) -> Polynomial { + let mut i = MclFr::from(1); + let mut polys = vec![]; + + // create (x-i) polynomials + while &i <= &num_constraints { + let poly = Polynomial::new(&vec![ + -&i, + MclFr::from(1), + ]); + polys.push(poly); + i.inc(); + } + // aggregate (x-i) polynomial into a single polynomial + let mut acc_poly = Polynomial::new(&vec![MclFr::from(1)]); + for poly in polys { + acc_poly = acc_poly.mul(&poly); + } + acc_poly + } + + pub fn build(r1cs: &R1CS) -> QAP { + /* + c1 c2 c3 (coeffs for a1, a2, a3) + for w=( 1, 2, 3), + at x=1, 3 * 1 (w[3] * a3) + at x=2, 2 * 3 (w[1] * a1) + at x=3, 0 + at x=4, 2 * 2 (w[2] * a2) + */ + + // w1 w2 w3 <- witness e.g. w=(x,y,z,w1,...) + // x=1 | 0 0 1 | + // x=2 | 3 0 0 | + // x=3 | 0 0 0 | + // x=4 | 0 2 0 | + // ^ + // +-- constraints + let constraints = r1cs.to_constraint_matrices(); + + // x=1 2 3 4 <- constraints + // w1 [0 3 0 0] + // w2 [0 0 0 2] + // w3 [1 0 0 0] + // ^ + // +-- witness + let constraints_v_t = constraints.a.transpose(); + let constraints_w_t = constraints.b.transpose(); + let constraints_y_t = constraints.c.transpose(); + println!("- const v_t\n{:?}", &constraints_v_t); + println!("- const w_t\n{:?}", &constraints_w_t); + println!("- const y_t\n{:?}", &constraints_y_t); + + // build polynomials for each wirness variable + // e.g. vi[0] is a polynomial for the first witness variable + // and returns 3 at x=2 and 0 at all other x's + let mut vi = vec![]; + let mut wi = vec![]; + let mut yi = vec![]; + + let mut i = MclFr::from(0); + + let num_witness_values = &r1cs.witness.size; + + while &i < num_witness_values { + // extract a constraint row + // x = 1 2 3 4 + // wi = [0 3 0 0] + let v_row = constraints_v_t.get_row(&i); + let w_row = constraints_w_t.get_row(&i); + let y_row = constraints_y_t.get_row(&i); + + // convert a constraint row to a polynomial + let v_poly = QAP::build_polynomial(&v_row); + let w_poly = QAP::build_polynomial(&w_row); + let y_poly = QAP::build_polynomial(&y_row); + + vi.push(v_poly); + wi.push(w_poly); + yi.push(y_poly); + + i.inc(); + } + + let num_constraints = constraints.a.height.clone(); + + QAP { vi, wi, yi, num_constraints } + } + + pub fn is_valid( + &self, + witness: &MclSparseVec, + num_constraints: &MclFr, + ) -> bool { + let t = QAP::build_t(num_constraints); + let p = self.build_p(witness); + + match p.divide_by(&t) { + DivResult::Quotient(_) => true, + DivResult::QuotientRemainder(_) => false, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::building_block::mcl::{ + mcl_fr::MclFr, + mcl_initializer::MclInitializer, + qap::constraint::Constraint, + }; + + #[test] + fn test_r1cs_to_polynomial() { + MclInitializer::init(); + // x out t1 y t2 + // 0 1 2 3 4 5 + // [1, 3, 35, 9, 27, 30] + let witness = MclSparseVec::from(&vec![ + MclFr::from(1), + MclFr::from(3), + MclFr::from(35), + MclFr::from(9), + MclFr::from(27), + MclFr::from(30), + ]); + let witness_size = &witness.size; + + // A + // 0 1 2 3 4 5 + // [0, 1, 0, 0, 0, 0] + // [0, 0, 0, 1, 0, 0] + // [0, 1, 0, 0, 1, 0] + // [5, 0, 0, 0, 0, 1] + let mut a1 = MclSparseVec::new(witness_size); + a1.set(&MclFr::from(1), &MclFr::from(1)); + + let mut a2 = MclSparseVec::new(witness_size); + a2.set(&MclFr::from(3), &MclFr::from(1)); + + let mut a3 = MclSparseVec::new(witness_size); + a3.set(&MclFr::from(1), &MclFr::from(1)); + a3.set(&MclFr::from(4), &MclFr::from(1)); + + let mut a4 = MclSparseVec::new(witness_size); + a4.set(&MclFr::from(0), &MclFr::from(5)); + a4.set(&MclFr::from(5), &MclFr::from(1)); + + // B + // 0 1 2 3 4 5 + // [0, 1, 0, 0, 0, 0] + // [0, 1, 0, 0, 0, 0] + // [1, 0, 0, 0, 0, 0] + // [1, 0, 0, 0, 0, 0] + let mut b1 = MclSparseVec::new(witness_size); + b1.set(&MclFr::from(1), &MclFr::from(1)); + + let mut b2 = MclSparseVec::new(witness_size); + b2.set(&MclFr::from(1), &MclFr::from(1)); + + let mut b3 = MclSparseVec::new(witness_size); + b3.set(&MclFr::from(0), &MclFr::from(1)); + + let mut b4 = MclSparseVec::new(witness_size); + b4.set(&MclFr::from(0), &MclFr::from(1)); + + // C + // 0 1 2 3 4 5 + // [0, 0, 0, 1, 0, 0] + // [0, 0, 0, 0, 1, 0] + // [0, 0, 0, 0, 0, 1] + // [0, 0, 1, 0, 0, 0] + let mut c1 = MclSparseVec::new(witness_size); + c1.set(&MclFr::from(3), &MclFr::from(1)); + + let mut c2 = MclSparseVec::new(witness_size); + c2.set(&MclFr::from(4), &MclFr::from(1)); + + let mut c3 = MclSparseVec::new(witness_size); + c3.set(&MclFr::from(5), &MclFr::from(1)); + + let mut c4 = MclSparseVec::new(witness_size); + c4.set(&MclFr::from(2), &MclFr::from(1)); + let constraints = vec![ + Constraint::new(&a1, &b1, &c1), + Constraint::new(&a2, &b2, &c2), + Constraint::new(&a3, &b3, &c3), + Constraint::new(&a4, &b4, &c4), + ]; + let num_constraints = &MclFr::from(constraints.len()); + let mid_beg = MclFr::from(3); + let r1cs = R1CS { + constraints, + witness: witness.clone(), + mid_beg, + }; + + let qap = QAP::build(&r1cs); + let is_passed = qap.is_valid(&witness, num_constraints); + assert!(is_passed); + } + + #[test] + fn test_build_t() { + MclInitializer::init(); + let one = &MclFr::from(1); + let two = &MclFr::from(2); + let neg_three = &-MclFr::from(3); + + // (x-1)(x-2) = x^2 - 3x + 2 + let z = QAP::build_t(two); + + // expect [2, -3, 1] + assert_eq!(z.len(), 3); + assert_eq!(&z[0], two); + assert_eq!(&z[1], neg_three); + assert_eq!(&z[2], one); + } +} diff --git a/src/building_block/mcl/qap/r1cs.rs b/src/building_block/mcl/qap/r1cs.rs new file mode 100644 index 0000000..b5cde33 --- /dev/null +++ b/src/building_block/mcl/qap/r1cs.rs @@ -0,0 +1,159 @@ +use crate::building_block::mcl::{ + mcl_fr::MclFr, + mcl_sparse_vec::MclSparseVec, + mcl_sparse_matrix::MclSparseMatrix, + qap::{ + constraint::Constraint, + r1cs_tmpl::R1CSTmpl, + term::Term, + }, +}; +use std::collections::HashMap; +use num_traits::Zero; + +#[derive(Clone)] +pub struct R1CS { + pub constraints: Vec, + pub witness: MclSparseVec, + pub mid_beg: MclFr, +} + +// matrix representing a constraint whose +// row is the multiples of each witness value i.e. +// a = [a1, a2, ...] +// * +// b = [b1, b2, ...] +// || +// c = [c1, c2, ...] +#[derive(Debug)] +pub struct ConstraintMatrices { + pub a: MclSparseMatrix, + pub b: MclSparseMatrix, + pub c: MclSparseMatrix, +} + +impl R1CS { + // build the witness vector that is in the order expected by the prover + // while validating the witness vector with the witness_instance + fn build_witness_vec( + tmpl: &R1CSTmpl, + witness_instance: &HashMap, + ) -> Result<(MclSparseVec, MclFr), String> { + + // build witness vector with values assigned + let mut witness = MclSparseVec::new(&MclFr::from(tmpl.witness.len())); + + let mut i = MclFr::zero(); + + for term in tmpl.witness.iter() { + if !witness_instance.contains_key(&term) { + return Err(format!("'{:?}' is missing in witness_instance", term)); + } + witness[&i] = witness_instance.get(&term).unwrap().clone(); + i.inc(); + } + + Ok((witness, tmpl.mid_beg.clone())) + } + + // evaluate all constraints and confirm they all hold + pub fn validate(&self) -> Result<(), String> { + for constraint in &self.constraints { + let a = &(&constraint.a * &self.witness).sum(); + let b = &(&constraint.b * &self.witness).sum(); + let c = &(&constraint.c * &self.witness).sum(); + + println!("r1cs: {:?} * {:?} = {:?}", &a, &b, &c); + if &(a * b) != c { + return Err(format!("Constraint a ({:?}) * b ({:?}) = c ({:?}) doesn't hold", a, b, c)); + } + } + println!(""); + Ok(()) + } + + pub fn from_tmpl( + tmpl: &R1CSTmpl, + witness: &HashMap, + ) -> Result { + let (witness, mid_beg) = R1CS::build_witness_vec(tmpl, witness)?; + let r1cs = R1CS { + constraints: tmpl.constraints.clone(), + witness, + mid_beg, + }; + Ok(r1cs) + } + + pub fn to_constraint_by_witness_matrices(&self) -> ConstraintMatrices { + let mut a = vec![]; + let mut b = vec![]; + let mut c = vec![]; + + for constraint in &self.constraints { + a.push(&constraint.a * &self.witness); + b.push(&constraint.b * &self.witness); + c.push(&constraint.c * &self.witness); + } + + ConstraintMatrices { + a: MclSparseMatrix::from(&a), + b: MclSparseMatrix::from(&b), + c: MclSparseMatrix::from(&c), + } + } + + pub fn to_constraint_matrices(&self) -> ConstraintMatrices { + let mut a = vec![]; + let mut b = vec![]; + let mut c = vec![]; + + for constraint in &self.constraints { + a.push(constraint.a.clone()); + b.push(constraint.b.clone()); + c.push(constraint.c.clone()); + } + + ConstraintMatrices { + a: MclSparseMatrix::from(&a), + b: MclSparseMatrix::from(&b), + c: MclSparseMatrix::from(&c), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::building_block::mcl::{ + mcl_initializer::MclInitializer, + qap::{ + equation_parser::EquationParser, + gate::Gate, + r1cs_tmpl::R1CSTmpl, + term::Term, + }, + }; + + #[test] + fn test_validate() { + MclInitializer::init(); + let input = "(x + 2) + 4 * y == 21"; + let eq = EquationParser::parse(input).unwrap(); + + let gates = &Gate::build(&eq); + let tmpl = &R1CSTmpl::new(gates); + + let witness = HashMap::::from([ + (Term::One, MclFr::from(1)), + (Term::Var("x".to_string()), MclFr::from(3)), + (Term::Var("y".to_string()), MclFr::from(4)), + (Term::Out, eq.rhs), + (Term::TmpVar(1), MclFr::from(5)), + (Term::TmpVar(2), MclFr::from(16)), + (Term::TmpVar(3), MclFr::from(21)), + ]); + let r1cs = R1CS::from_tmpl(tmpl, &witness).unwrap(); + r1cs.validate().unwrap(); + } +} diff --git a/src/building_block/mcl/qap/r1cs_tmpl.rs b/src/building_block/mcl/qap/r1cs_tmpl.rs new file mode 100644 index 0000000..05ae4f9 --- /dev/null +++ b/src/building_block/mcl/qap/r1cs_tmpl.rs @@ -0,0 +1,421 @@ +use crate::building_block::mcl::{ + mcl_fr::MclFr, + mcl_sparse_vec::MclSparseVec, + qap::{ + term::Term, + gate::Gate, + constraint::Constraint, + }, +}; +use std::collections::HashMap; +use num_traits::Zero; + +pub struct R1CSTmpl { + pub constraints: Vec, + pub witness: Vec, + pub indices: HashMap, + pub mid_beg: MclFr, +} + +impl R1CSTmpl { + // build witness vector whose elements in the following order: + // 1, inputs, Out, mid + fn build_witness( + inputs: &Vec, + mid: &Vec, + witness: &mut Vec, + indices: &mut HashMap::, + ) -> MclFr { + let mut i = MclFr::from(1); // `1` has already been added in new function + + for x in inputs { + witness.push(x.clone()); + indices.insert(x.clone(), i.clone()); + i.inc(); + } + witness.push(Term::Out); + indices.insert(Term::Out, i.clone()); + i.inc(); + + let mid_beg = i.clone(); + + for x in mid { + witness.push(x.clone()); + indices.insert(x.clone(), i.clone()); + i.inc(); + } + + mid_beg + } + + fn categorize_witness_terms( + t: &Term, + inputs: &mut Vec, + mid: &mut Vec, + ) { + match t { + Term::One => (), // not categorized as inputs or mid + Term::Num(_) => (), // Num is represented as multiple of Term::One, so not adding to witness + Term::Out => (), // not categorized as inputs or mid + Term::Var(_) => if !inputs.contains(&t) { inputs.push(t.clone()) }, + Term::TmpVar(_) => if !mid.contains(&t) { mid.push(t.clone()) }, + Term::Sum(a, b) => { + R1CSTmpl::categorize_witness_terms(&a, inputs, mid); + R1CSTmpl::categorize_witness_terms(&b, inputs, mid); + }, + } + } + + fn build_constraint_vec( + vec: &mut MclSparseVec, + term: &Term, + indices: &HashMap::, + ) { + match term { + Term::Sum(a, b) => { + R1CSTmpl::build_constraint_vec(vec, &a, indices); + R1CSTmpl::build_constraint_vec(vec, &b, indices); + }, + Term::Num(n) => { + vec.set(&MclFr::zero(), n); // Num is represented as Term::One at index 0 times n + }, + x => { + let index = indices.get(&x).unwrap(); + vec.set(index, &MclFr::from(1)); + }, + } + } + + pub fn new(gates: &[Gate]) -> Self { + let mut witness = vec![]; + let mut indices = HashMap::::new(); + + // add `1` at index 0 + witness.push(Term::One); + indices.insert(Term::One, MclFr::zero()); + + // categoraize terms contained in gates to inputs and mid + let mut inputs = vec![]; + let mut mid = vec![]; + + for gate in gates { + R1CSTmpl::categorize_witness_terms(&gate.a, &mut inputs, &mut mid); + R1CSTmpl::categorize_witness_terms(&gate.b, &mut inputs, &mut mid); + R1CSTmpl::categorize_witness_terms(&gate.c, &mut inputs, &mut mid); + } + + let mid_beg = R1CSTmpl::build_witness(&inputs, &mid, &mut witness, &mut indices); + let vec_size = &MclFr::from(witness.len()); + let mut constraints = vec![]; + + // create a, b anc c vectors for each gate + for gate in gates { + let mut a = MclSparseVec::new(vec_size); + R1CSTmpl::build_constraint_vec(&mut a, &gate.a, &indices); + + let mut b = MclSparseVec::new(vec_size); + R1CSTmpl::build_constraint_vec(&mut b, &gate.b, &indices); + + let mut c = MclSparseVec::new(vec_size); + R1CSTmpl::build_constraint_vec(&mut c, &gate.c, &indices); + + let constraint = Constraint { a, b, c }; + constraints.push(constraint) + } + + R1CSTmpl { + constraints, + witness, + indices, + mid_beg, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::building_block::mcl::{ + qap::equation_parser::EquationParser, + mcl_initializer::MclInitializer, + }; + + #[test] + fn test_categorize_witness_terms() { + MclInitializer::init(); + // Num term should not be categorized as input or mid + { + let mut inputs = vec![]; + let mut mid = vec![]; + let term = &Term::Num(MclFr::from(9)); + R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid); + assert_eq!(inputs.len(), 0); + assert_eq!(mid.len(), 0); + } + + // One term should not be categorized as input or mid + { + let mut inputs = vec![]; + let mut mid = vec![]; + let term = &Term::One; + R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid); + assert_eq!(inputs.len(), 0); + assert_eq!(mid.len(), 0); + } + + // Var term should be categorized as input + { + let mut inputs = vec![]; + let mut mid = vec![]; + let term = &Term::Var("x".to_string()); + R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid); + assert_eq!(inputs.len(), 1); + assert_eq!(mid.len(), 0); + } + + // Out term should be not categorized as input or mid + { + let mut inputs = vec![]; + let mut mid = vec![]; + let term = &Term::Out; + R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid); + assert_eq!(inputs.len(), 0); + assert_eq!(mid.len(), 0); + } + + // TmpVar term should be categorized as mid + { + let mut inputs = vec![]; + let mut mid = vec![]; + let term = &Term::TmpVar(1); + R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid); + assert_eq!(inputs.len(), 0); + assert_eq!(mid.len(), 1); + } + + // Sum term should be recursively categorized + { + let mut inputs = vec![]; + let mut mid = vec![]; + let y = Term::Var("y".to_string()); + let z = Term::Var("z".to_string()); + let term = &Term::Sum(Box::new(y.clone()), Box::new(z.clone())); + R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid); + assert_eq!(inputs.len(), 2); + assert_eq!(mid.len(), 0); + } + } + + #[test] + fn test_build_witness() { + MclInitializer::init(); + let a = Term::Var("a".to_string()); + let b = Term::Var("b".to_string()); + let sum = Term::Sum(Box::new(a), Box::new(b)); + + let terms = vec![ + Term::Num(MclFr::from(9)), // Num should be ignored + Term::One, // One is discarded in categorize_witness_terms + Term::Var("x".to_string()), // Var should be added + Term::Var("x".to_string()), // the same Var should be added twice + Term::Var("y".to_string()), // different Var should be added + Term::TmpVar(1), // TmpVar should be added + Term::TmpVar(1), // same TmpVar should not be added twice + Term::TmpVar(2), // different TmpVar should be added + Term::Out, // Out should be added + Term::Out, // Out should not be added twice + sum, // sum should be added recursively + ]; + + let mut inputs = vec![]; + let mut mid = vec![]; + + for term in &terms { + R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid); + } + + let mut witness = vec![]; + let mut indices = HashMap::::new(); + + let mid_beg = R1CSTmpl::build_witness(&inputs, &mid, &mut witness, &mut indices); + assert!(mid_beg == MclFr::from(6)); + + // 7 since One has been discarded and Out is added in build_witness + assert_eq!(indices.len(), 7); + assert_eq!(witness.len(), 7); + + // check if witness is correctly built + let exp = vec![ + // One has been discarded + Term::Var("x".to_string()), + Term::Var("y".to_string()), + Term::Var("a".to_string()), + Term::Var("b".to_string()), + Term::Out, // build_witness adds Out + Term::TmpVar(1), + Term::TmpVar(2), + ]; + assert!(witness == exp); + + // check if indices map is correctly built + assert!(indices.get(&Term::One).is_none()); + assert!(indices.get(&Term::Var("x".to_string())).unwrap() == &MclFr::from(1)); + assert!(indices.get(&Term::Var("y".to_string())).unwrap() == &MclFr::from(2)); + assert!(indices.get(&Term::Var("a".to_string())).unwrap() == &MclFr::from(3)); + assert!(indices.get(&Term::Var("b".to_string())).unwrap() == &MclFr::from(4)); + assert!(indices.get(&Term::Out).unwrap() == &MclFr::from(5)); + assert!(indices.get(&Term::TmpVar(1)).unwrap() == &MclFr::from(6)); + assert!(indices.get(&Term::TmpVar(2)).unwrap() == &MclFr::from(7)); + } + + #[test] + fn test_new() { + MclInitializer::init(); + let gates = vec![]; + let tmpl = R1CSTmpl::new( &gates); + assert_eq!(tmpl.indices.len(), 2); + + // if gates is empty, witness should contain only One term and Out term + assert_eq!(tmpl.indices.get(&Term::One).unwrap(), &MclFr::from(0)); + assert_eq!(tmpl.indices.get(&Term::Out).unwrap(), &MclFr::from(1)); + assert_eq!(tmpl.witness.len(), 2); + assert_eq!(tmpl.witness[0], Term::One); + assert_eq!(tmpl.witness[1], Term::Out); + } + + #[test] + fn test_constraint_generation() { + MclInitializer::init(); + { + // Num + let mut inputs = vec![]; + let mut mid = vec![]; + let term = &Term::Num(MclFr::from(4)); + R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid); + + let mut witness = vec![]; + let mut indices = HashMap::::new(); + let mid_beg = R1CSTmpl::build_witness(&inputs, &mid, &mut witness, &mut indices); + assert!(mid_beg == MclFr::from(2)); + + let mut constraint = MclSparseVec::new(&MclFr::from(2)); + R1CSTmpl::build_constraint_vec(&mut constraint, &term, &indices); + + // should be mapped to One term at index 0 + assert_eq!(constraint.get(&MclFr::zero()), &MclFr::from(4)); + } + { + // Sum + let mut inputs = vec![]; + let mut mid = vec![]; + + let y = Term::Var("y".to_string()); + let z = Term::Var("z".to_string()); + let term = &Term::Sum(Box::new(y.clone()), Box::new(z.clone())); + R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid); + + let mut witness = vec![]; + let mut indices = HashMap::::new(); + let mid_beg = R1CSTmpl::build_witness(&inputs, &mid, &mut witness, &mut indices); + assert!(mid_beg == MclFr::from(4)); + + let mut constraint = MclSparseVec::new(&MclFr::from(3)); + R1CSTmpl::build_constraint_vec(&mut constraint, &term, &indices); + + // y and z should be stored at index 1 and 2 of witness vector respectively + assert_eq!(constraint.get(&MclFr::from(1)), &MclFr::from(1)); + assert_eq!(constraint.get(&MclFr::from(1)), &MclFr::from(1)); + } + } + + + #[test] + fn test_witness_indices() { + MclInitializer::init(); + let input = "(3 * x + 4) / 2 == 11"; + let eq = EquationParser::parse(input).unwrap(); + + let gates = &Gate::build(&eq); + let tmpl = R1CSTmpl::new(gates); + + let h = tmpl.indices; + let w = [ + Term::One, + Term::Var("x".to_string()), + Term::Out, + Term::TmpVar(1), + Term::TmpVar(2), + Term::TmpVar(3), + ]; + assert_eq!(h.len(), w.len()); + + for (i, term) in w.iter().enumerate() { + assert_eq!(h.get(&term).unwrap(), &MclFr::from(i)); + } + } + + fn term_to_str(tmpl: &R1CSTmpl, vec: &MclSparseVec) -> String { + MclInitializer::init(); + let mut indices = vec.indices().to_vec(); + indices.sort(); // sort to make indices order deterministic + let s = indices.iter().map(|i| { + let i_usize: usize = i.to_usize(); + match &tmpl.witness[i_usize] { + Term::Var(s) => s.clone(), + Term::TmpVar(i) => format!("t{}", i), + Term::One => format!("{:?}", &vec.get(i)), + Term::Out => "out".to_string(), + // currently not handling Term::Sum since it's not used in tests + _ => "?".to_string(), + } + }).collect::>().join(" + "); + format!("{}", s) + } + + #[test] + fn test_r1cs_build_a_b_c_matrix() { + MclInitializer::init(); + let input = "3 * x + 4 == 11"; + let eq = EquationParser::parse(input).unwrap(); + + let gates = &Gate::build(&eq); + let tmpl = R1CSTmpl::new(gates); + + let mut res = vec![]; + for constraint in &tmpl.constraints { + let a = term_to_str(&tmpl, &constraint.a); + let b = term_to_str(&tmpl, &constraint.b); + let c = term_to_str(&tmpl, &constraint.c); + res.push((a, b, c)); + } + + assert_eq!(res.len(), 3); + assert_eq!(res[0], ("3".to_string(), "x".to_string(), "t1".to_string())); + assert_eq!(res[1], ("4 + t1".to_string(), "1".to_string(), "t2".to_string())); + assert_eq!(res[2], ("t2".to_string(), "1".to_string(), "out".to_string())); + } + + #[test] + fn blog_post_1_example_1() { + MclInitializer::init(); + let expr = "(x * x * x) + x + 5 == 35"; + let eq = EquationParser::parse(expr).unwrap(); + let gates = &Gate::build(&eq); + let r1cs_tmpl = R1CSTmpl::new(gates); + + println!("{:?}", r1cs_tmpl.witness); + } + + #[test] + fn blog_post_1_example_2() { + MclInitializer::init(); + let expr = "(x * x * x) + x + 5 == 35"; + let eq = EquationParser::parse(expr).unwrap(); + let gates = &Gate::build(&eq); + let r1cs_tmpl = R1CSTmpl::new(gates); + + println!("w = {:?}", r1cs_tmpl.witness); + println!("{:?}", r1cs_tmpl.constraints); + } +} + diff --git a/src/building_block/mcl/qap/term.rs b/src/building_block/mcl/qap/term.rs new file mode 100644 index 0000000..cd79b4d --- /dev/null +++ b/src/building_block/mcl/qap/term.rs @@ -0,0 +1,33 @@ +use crate::building_block::mcl::{ + mcl_fr::MclFr, + qap::config::SignalId, +}; + +#[derive(Clone, Hash, PartialEq, Eq)] +pub enum Term { + Num(MclFr), + One, + Out, + Sum(Box, Box), // Sum will not contain Out and Sum itself + TmpVar(SignalId), + Var(String), +} + +impl Term { + pub fn var(name: &str) -> Term { + Term::Var(name.to_string()) + } +} + +impl std::fmt::Debug for Term { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Term::Num(n) => write!(f, "{:?}", n), + Term::One => write!(f, "1"), + Term::Out => write!(f, "out"), + Term::Sum(a, b) => write!(f, "({:?} + {:?})", a, b), + Term::TmpVar(n) => write!(f, "t{:?}", n), + Term::Var(s) => write!(f, "{:?}", s), + } + } +} diff --git a/src/building_block/mod.rs b/src/building_block/mod.rs index 38021a6..8bfaa73 100644 --- a/src/building_block/mod.rs +++ b/src/building_block/mod.rs @@ -1,6 +1,7 @@ // pub mod elliptic_curve; pub mod field; pub mod hasher; +pub mod mcl; pub mod random_number; pub mod curves; pub mod to_bigint; diff --git a/src/zk/w_trusted_setup/groth16_mcl/crs.rs b/src/zk/w_trusted_setup/groth16_mcl/crs.rs new file mode 100644 index 0000000..8230aee --- /dev/null +++ b/src/zk/w_trusted_setup/groth16_mcl/crs.rs @@ -0,0 +1,138 @@ +use crate::{ + building_block::mcl::{ + mcl_fr::MclFr, + mcl_g1::MclG1, + mcl_g2::MclG2, + mcl_gt::MclGT, + pairing::Pairing, + qap::qap::QAP, + }, + zk::w_trusted_setup::groth16_mcl::prover::Prover, +}; + +pub struct G1 { + pub alpha: MclG1, + pub beta: MclG1, + pub delta: MclG1, + pub xi: Vec, // x powers + pub uvw_stmt: Vec, // beta*u(x) + alpha*v(x) + w(x) / div (statement) + pub uvw_wit: Vec, // beta*u(x) + alpha*v(x) + w(x) / div (witness) + pub ht_by_delta: MclG1, // h(x) * t(x) / delta +} + +pub struct G2 { + pub beta: MclG2, + pub gamma: MclG2, + pub delta: MclG2, + pub xi: Vec, // x powers +} + +pub struct GT { + pub alpha_beta: MclGT, +} + +#[allow(non_snake_case)] +pub struct CRS { + pub g1: G1, + pub g2: G2, + pub gt: GT, +} + +impl CRS { + // 0, 1, .., l, l+1, .., m + // +---------+ +--------+ + // statement witness + pub fn new( + prover: &Prover, + pairing: &Pairing, + ) -> Self { + println!("--> Building sigma..."); + let g = &MclG1::g(); + let h = &MclG2::g(); + + // sample random non-zero field element + let alpha = &MclFr::rand(true); + let beta = &MclFr::rand(true); + let gamma = &MclFr::rand(true); + let delta = &MclFr::rand(true); + let x = &MclFr::rand(true); + + macro_rules! calc_uvw_div { + ($from: expr, $to: expr, $div_factor: expr) => { + { + let mut ys: Vec = vec![]; + let mut i = $from.clone(); + + while &i <= $to { + let ui = beta * &prover.ui[i].eval_at(x); + let vi = alpha * &prover.vi[i].eval_at(x); + let wi = &prover.wi[i].eval_at(x); + let y = (ui + vi + wi) * $div_factor; + + ys.push(g * y); + i += 1; + } + ys + } + } + } + + let uvw_stmt = calc_uvw_div!(0, &prover.l.to_usize(), &gamma.inv()); + let uvw_wit = calc_uvw_div!(&prover.l.to_usize() + 1, &prover.m.to_usize(), &delta.inv()); + + macro_rules! calc_n_pows { + ($point_type: ty, $x: expr) => { + { + let generator = &<$point_type>::g(); + let mut ys: Vec<$point_type> = vec![]; + let mut x_pow = MclFr::from(1); + + for _ in 0..prover.n.to_usize(){ + ys.push(generator * &x_pow); + x_pow = x_pow * x; + } + ys + } + } + } + + let xi_g1 = calc_n_pows!(MclG1, x); + + let ht_by_delta = { + let h = &prover.h.eval_at(x); + let t = &QAP::build_t(&prover.n).eval_at(x); + let v = h * t * &delta.inv(); + MclG1::g() * &v + }; + + let g1 = G1 { + alpha: g * alpha, + beta: g * beta, + delta: g * delta, + xi: xi_g1, + uvw_stmt, + uvw_wit, + ht_by_delta, + }; + + let xi_g2 = calc_n_pows!(MclG2, x); + + let g2 = G2 { + beta: h * beta, + gamma: h * gamma, + delta: h * delta, + xi: xi_g2, + }; + + let gt = GT { + alpha_beta: pairing.e(&g1.alpha, &g2.beta), + }; + + CRS { + g1, + g2, + gt, + } + } +} + diff --git a/src/zk/w_trusted_setup/groth16_mcl/mod.rs b/src/zk/w_trusted_setup/groth16_mcl/mod.rs new file mode 100644 index 0000000..a9cdf45 --- /dev/null +++ b/src/zk/w_trusted_setup/groth16_mcl/mod.rs @@ -0,0 +1,5 @@ +pub mod crs; +pub mod proof; +pub mod prover; +pub mod verifier; +pub mod wires; diff --git a/src/zk/w_trusted_setup/groth16_mcl/proof.rs b/src/zk/w_trusted_setup/groth16_mcl/proof.rs new file mode 100644 index 0000000..95acc3c --- /dev/null +++ b/src/zk/w_trusted_setup/groth16_mcl/proof.rs @@ -0,0 +1,12 @@ +use crate::building_block::mcl::{ + mcl_g1::MclG1, + mcl_g2::MclG2, +}; + +#[allow(non_snake_case)] +pub struct Proof { + pub A: MclG1, + pub B: MclG2, + pub C: MclG1, +} + diff --git a/src/zk/w_trusted_setup/groth16_mcl/prover.rs b/src/zk/w_trusted_setup/groth16_mcl/prover.rs new file mode 100644 index 0000000..1d2ffb7 --- /dev/null +++ b/src/zk/w_trusted_setup/groth16_mcl/prover.rs @@ -0,0 +1,179 @@ +use crate::{ + building_block::mcl::{ + mcl_fr::MclFr, + mcl_g1::MclG1, + mcl_g2::MclG2, + polynomial::{ + DivResult, + Polynomial, + }, + qap::{ + equation_parser::EquationParser, + gate::Gate, + qap::QAP, + r1cs::R1CS, + r1cs_tmpl::R1CSTmpl, + term::Term, + }, + }, + zk::w_trusted_setup::groth16_mcl::{ + crs::CRS, + proof::Proof, + wires::Wires, + }, +}; +use num_traits::Zero; +use std::collections::HashMap; + +pub struct Prover { + pub n: MclFr, // # of constraints + pub l: MclFr, // end index of statement variables + pub m: MclFr, // end index of statement + witness variables + pub wires: Wires, + pub h: Polynomial, + pub t: Polynomial, + pub ui: Vec, + pub vi: Vec, + pub wi: Vec, +} + +impl Prover { + pub fn new( + expr: &str, + witness_map: &HashMap, + ) -> Self { + let eq = EquationParser::parse(expr).unwrap(); + + let gates = &Gate::build(&eq); + let tmpl = &R1CSTmpl::new(gates); + + let r1cs = R1CS::from_tmpl(tmpl, &witness_map).unwrap(); + r1cs.validate().unwrap(); + + let qap = QAP::build(&r1cs); + + let t = QAP::build_t(&MclFr::from(tmpl.constraints.len())); + let h = { + let p = qap.build_p(&r1cs.witness); + match p.divide_by(&t) { + DivResult::Quotient(q) => q, + _ => panic!("p should be divisible by t"), + } + }; + + let l = &tmpl.mid_beg - MclFr::from(1); + let m = MclFr::from(tmpl.witness.len() - 1); + let wires = Wires::new(&r1cs.witness.clone(), &l); + let n = MclFr::from(tmpl.constraints.len()); + + Prover { + n, + l, + m, + wires, + t, + h, + ui: qap.vi.clone(), + vi: qap.wi.clone(), + wi: qap.yi.clone(), + } + } + + #[allow(non_snake_case)] + pub fn prove(&self, crs: &CRS) -> Proof { + println!("--> Generating proof..."); + let r = &MclFr::rand(true); + let s = &MclFr::rand(true); + + let (A, B, B_g1) = { + let mut sum_term_A = MclG1::zero(); + let mut sum_term_B = MclG2::zero(); + let mut sum_term_B_g1 = MclG1::zero(); + + for i in 0..=self.m.to_usize() { + let ai = &self.wires[i]; + let ui_prod = self.ui[i].eval_with_g1_hidings(&crs.g1.xi) * ai; + let vi_prod = self.vi[i].eval_with_g2_hidings(&crs.g2.xi) * ai; + let vi_prod_g1 = self.vi[i].eval_with_g1_hidings(&crs.g1.xi) * ai; + + sum_term_A += ui_prod; + sum_term_B += vi_prod; + sum_term_B_g1 += vi_prod_g1; + } + let A = &crs.g1.alpha + &sum_term_A + &crs.g1.delta * r; + let B = &crs.g2.beta + &sum_term_B + &crs.g2.delta * s; + let B_g1 = &crs.g1.beta + &sum_term_B_g1 + &crs.g1.delta * s; + (A, B, B_g1) + }; + + let C = { + let mut sum = MclG1::zero(); + + let wit_beg = self.l.to_usize() + 1; + for i in wit_beg..=self.m.to_usize() { + let ai = &self.wires[i]; + sum += &crs.g1.uvw_wit[i - wit_beg] * ai; + } + sum + + &crs.g1.ht_by_delta + + &A * s + + &B_g1 * r + + -(&crs.g1.delta * r * s) + }; + + Proof { + A, + B, + C, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + building_block::mcl::{ + pairing::Pairing, + mcl_initializer::MclInitializer, + }, + zk::w_trusted_setup::groth16_mcl::verifier::Verifier, + }; + + #[test] + fn test_generate_proof_and_verify() { + MclInitializer::init(); + + let expr = "(x * x * x) + x + 5 == 35"; + println!("Expr: {}\n", expr); + let eq = EquationParser::parse(expr).unwrap(); + + let witness_map = { + use crate::building_block::mcl::qap::term::Term::*; + HashMap::::from([ + (Term::One, MclFr::from(1)), + (Term::var("x"), MclFr::from(3)), + (TmpVar(1), MclFr::from(9)), + (TmpVar(2), MclFr::from(27)), + (TmpVar(3), MclFr::from(8)), + (TmpVar(4), MclFr::from(35)), + (Out, eq.rhs), + ]) + }; + let prover = &Prover::new(expr, &witness_map); + let pairing = &Pairing; + let verifier = &Verifier::new(pairing); + let crs = CRS::new(prover, pairing); + + let proof = prover.prove(&crs); + let stmt_wires = &prover.wires.statement(); + let result = verifier.verify( + &proof, + &crs, + stmt_wires, + ); + + assert!(result); + } +} + diff --git a/src/zk/w_trusted_setup/groth16_mcl/verifier.rs b/src/zk/w_trusted_setup/groth16_mcl/verifier.rs new file mode 100644 index 0000000..ec9a873 --- /dev/null +++ b/src/zk/w_trusted_setup/groth16_mcl/verifier.rs @@ -0,0 +1,55 @@ +// Implementation of protocol 2 described on page 5 in https://eprint.iacr.org/2013/279.pdf + +use crate::{ + building_block::mcl::{ + mcl_fr::MclFr, + mcl_g1::MclG1, + mcl_g2::MclG2, + mcl_sparse_vec::MclSparseVec, + pairing::Pairing, + }, + zk::w_trusted_setup::groth16_mcl::{ + crs::CRS, + proof::Proof, + }, +}; +use num_traits::Zero; + +pub struct Verifier { + pairing: Pairing, +} + +impl Verifier { + pub fn new(pairing: &Pairing) -> Self { + Verifier { + pairing: pairing.clone(), + } + } + + pub fn verify( + &self, + proof: &Proof, + crs: &CRS, + stmt_wires: &MclSparseVec, + ) -> bool { + let e = |a: &MclG1, b: &MclG2| self.pairing.e(a, b); + + println!("--> Verifying Groth16 proof..."); + let lhs = e(&proof.A, &proof.B); + + let mut sum_term = MclG1::zero(); + for i in 0..stmt_wires.size.to_usize() { + let ai = &stmt_wires[&MclFr::from(i)]; + sum_term += &crs.g1.uvw_stmt[i] * ai; + } + + let rhs = + &crs.gt.alpha_beta + * e(&sum_term, &crs.g2.gamma) + * e(&proof.C, &crs.g2.delta) + ; + + lhs == rhs + } +} + diff --git a/src/zk/w_trusted_setup/groth16_mcl/wires.rs b/src/zk/w_trusted_setup/groth16_mcl/wires.rs new file mode 100644 index 0000000..a6ed951 --- /dev/null +++ b/src/zk/w_trusted_setup/groth16_mcl/wires.rs @@ -0,0 +1,78 @@ +use crate::building_block::mcl::{ + mcl_fr::MclFr, + mcl_sparse_vec::MclSparseVec, +}; +use core::ops::Index; +use num_traits::Zero; + +// wires: +// 0, 1, .., l, l+1, .., m +// +---------+ +--------+ +// statement witness +pub struct Wires { + sv: MclSparseVec, + witness_beg: MclFr, +} + +impl Wires { + // l is index of the last statement wire + pub fn new(sv: &MclSparseVec, l: &MclFr) -> Self { + Wires { + sv: sv.clone(), + witness_beg: l + MclFr::from(1), + } + } + + pub fn statement(&self) -> MclSparseVec { + self.sv.slice(&MclFr::zero(), &self.witness_beg) + } + + pub fn witness(&self) -> MclSparseVec { + let from = &self.witness_beg; + let to = &self.sv.size; + self.sv.slice(from, to) + } +} + +impl Index for Wires { + type Output = MclFr; + + fn index(&self, index: usize) -> &Self::Output { + &self.sv[&MclFr::from(index)] + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_wire_indices() { + // [1,3,35,9,27,8,35] + let mut sv = MclSparseVec::new(&MclFr::from(7)); + sv[&MclFr::from(0)] = MclFr::from(1); + sv[&MclFr::from(1)] = MclFr::from(3); + sv[&MclFr::from(2)] = MclFr::from(35); // <-- l + sv[&MclFr::from(3)] = MclFr::from(9); + sv[&MclFr::from(4)] = MclFr::from(27); + sv[&MclFr::from(5)] = MclFr::from(8); + sv[&MclFr::from(6)] = MclFr::from(35); + + let w = Wires::new(&sv, &MclFr::from(2)); + + let st = &w.statement(); + assert!(st.size == MclFr::from(3)); + assert!(st[&MclFr::from(0)] == MclFr::from(1)); + assert!(st[&MclFr::from(1)] == MclFr::from(3)); + assert!(st[&MclFr::from(2)] == MclFr::from(35)); + + let wit = &w.witness(); + assert!(wit.size == MclFr::from(4)); + assert!(wit[&MclFr::from(0)] == MclFr::from(9)); + assert!(wit[&MclFr::from(1)] == MclFr::from(27)); + assert!(wit[&MclFr::from(2)] == MclFr::from(8)); + assert!(wit[&MclFr::from(3)] == MclFr::from(35)); + } +} + + diff --git a/src/zk/w_trusted_setup/mod.rs b/src/zk/w_trusted_setup/mod.rs index b476566..63b5402 100644 --- a/src/zk/w_trusted_setup/mod.rs +++ b/src/zk/w_trusted_setup/mod.rs @@ -1,3 +1,4 @@ pub mod groth16; +pub mod groth16_mcl; pub mod pinocchio; pub mod qap;