add arith modules based on mcl. add mcl-based groth16

This commit is contained in:
exfinen
2023-11-09 18:32:16 +09:00
parent b1485e8a27
commit 49b8475517
35 changed files with 6484 additions and 4 deletions

18
Cargo.lock generated
View File

@@ -20,6 +20,15 @@ dependencies = [
"wyz",
]
[[package]]
name = "cc"
version = "1.0.83"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0"
dependencies = [
"libc",
]
[[package]]
name = "cfg-if"
version = "1.0.0"
@@ -55,6 +64,14 @@ version = "0.2.120"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad5c14e80759d0939d013e6ca49930e59fc53dd8e5009132f76240c179380c09"
[[package]]
name = "mcl_rust"
version = "0.0.1"
source = "git+https://github.com/herumi/mcl-rust.git#50056e4f1cd7525042887cbcfe9aecaf8c4858da"
dependencies = [
"cc",
]
[[package]]
name = "memchr"
version = "2.5.0"
@@ -183,6 +200,7 @@ version = "0.1.0"
dependencies = [
"bitvec",
"hex",
"mcl_rust",
"nom",
"num-bigint",
"num-traits",

View File

@@ -12,4 +12,4 @@ num-traits = "0.2"
once_cell = "1.18.0"
rand = "0.8.5"
rand_chacha = "0.3.1"
mcl_rust = { git = "https://github.com/herumi/mcl-rust.git" }

View File

@@ -6,9 +6,9 @@ To implement cryptographic primitives in the simplest form without using any opt
## What's implemented so far
- Groth16 zk-SNARK
- Proof generation and verification
- An implementation fully based on zk-toolkit
- An implementation utilizing BLS12-381 curve of external mcl library
- Pinnochio zk-SNARK (protocol 2)
- Proof generation and verification
- Common zk-SNARK compnents
- Equation parser
- R1CS
@@ -36,4 +36,3 @@ To implement cryptographic primitives in the simplest form without using any opt
## What's NOT implemented so far
- Arbitrary-precision unsigned integer
- Random number generator

View File

@@ -0,0 +1,370 @@
use mcl_rust::*;
use std::{
cmp::Ordering,
convert::From,
fmt,
ops::{Add, Sub, Mul, Neg},
hash::{Hash, Hasher},
};
use num_traits::Zero;
#[derive(Clone)]
pub struct MclFr {
pub v: Fr,
}
impl MclFr {
pub fn new() -> Self {
let v = Fr::zero();
MclFr::from(&v)
}
pub fn inv(&self) -> Self {
let mut v = Fr::zero();
Fr::inv(&mut v, &self.v);
MclFr::from(&v)
}
pub fn sq(&self) -> Self {
let mut v = Fr::zero();
Fr::sqr(&mut v, &self.v);
MclFr::from(&v)
}
pub fn rand(exclude_zero: bool) -> Self {
let mut v = Fr::zero();
while {
Fr::set_by_csprng(&mut v);
v.is_zero() && exclude_zero
} {}
MclFr::from(&v)
}
pub fn inc(&mut self) {
let v = &self.v + &Fr::from_int(1);
self.v = v;
}
pub fn to_usize(&self) -> usize {
self.v.get_str(10).parse().unwrap()
}
}
impl Zero for MclFr {
fn is_zero(&self) -> bool {
self.v.is_zero()
}
fn zero() -> Self {
MclFr { v: Fr::zero() }
}
}
impl From<i32> for MclFr {
fn from(value: i32) -> Self {
let v = Fr::from_int(value);
MclFr { v }
}
}
impl From<usize> for MclFr {
fn from(value: usize) -> Self {
let value: i32 = value as i32;
let v = Fr::from_int(value);
MclFr { v }
}
}
impl From<&Fr> for MclFr {
fn from(v: &Fr) -> Self {
MclFr { v: v.clone() }
}
}
impl From<&str> for MclFr {
fn from(s: &str) -> Self {
let mut v = Fr::zero();
Fr::set_str(&mut v, s, 10);
MclFr { v }
}
}
impl From<bool> for MclFr {
fn from(b: bool) -> Self {
let v = {
if b {
Fr::from_int(1)
} else {
Fr::zero()
}
};
MclFr { v }
}
}
impl Ord for MclFr {
fn cmp(&self, other: &Self) -> Ordering {
let r = &self.v.cmp(&other.v);
if r.is_zero() {
Ordering::Equal
} else if r < &0 {
Ordering::Less
} else {
Ordering::Greater
}
}
}
impl PartialOrd for MclFr {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl PartialEq for MclFr {
fn eq(&self, other: &Self) -> bool {
self.v == other.v
}
}
impl Hash for MclFr {
fn hash<H: Hasher>(&self, state: &mut H) {
let mut buf: Vec<u8> = vec![];
let mut v = Fr::zero();
v.set_hash_of(&mut buf);
buf.hash(state);
}
}
impl Eq for MclFr {}
impl fmt::Debug for MclFr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.v.get_str(10))
}
}
impl fmt::Display for MclFr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.v.get_str(10))
}
}
macro_rules! impl_neg {
($target: ty) => {
impl Neg for $target {
type Output = MclFr;
fn neg(self) -> Self::Output {
let mut v = Fr::zero();
Fr::neg(&mut v, &self.v);
MclFr { v }
}
}
}
}
impl_neg!(MclFr);
impl_neg!(&MclFr);
macro_rules! impl_add {
($rhs: ty, $target: ty) => {
impl Add<$rhs> for $target {
type Output = MclFr;
fn add(self, rhs: $rhs) -> Self::Output {
let mut v = Fr::zero();
Fr::add(&mut v, &self.v, &rhs.v);
MclFr { v }
}
}
};
}
impl_add!(MclFr, MclFr);
impl_add!(&MclFr, MclFr);
impl_add!(MclFr, &MclFr);
impl_add!(&MclFr, &MclFr);
macro_rules! impl_sub {
($rhs: ty, $target: ty) => {
impl Sub<$rhs> for $target {
type Output = MclFr;
fn sub(self, rhs: $rhs) -> Self::Output {
let mut v = Fr::zero();
Fr::sub(&mut v, &self.v, &rhs.v);
MclFr { v }
}
}
};
}
impl_sub!(MclFr, MclFr);
impl_sub!(&MclFr, MclFr);
impl_sub!(MclFr, &MclFr);
impl_sub!(&MclFr, &MclFr);
macro_rules! impl_mul {
($rhs: ty, $target: ty) => {
impl Mul<$rhs> for $target {
type Output = MclFr;
fn mul(self, rhs: $rhs) -> Self::Output {
let mut v = Fr::zero();
Fr::mul(&mut v, &self.v, &rhs.v);
MclFr { v }
}
}
};
}
impl_mul!(MclFr, MclFr);
impl_mul!(&MclFr, MclFr);
impl_mul!(MclFr, &MclFr);
impl_mul!(&MclFr, &MclFr);
#[cfg(test)]
mod tests {
use super::*;
use crate::building_block::mcl::mcl_initializer::MclInitializer;
#[test]
fn test_add() {
MclInitializer::init();
let n3 = MclFr::from(3);
let n9 = MclFr::from(9);
let exp = MclFr::from(12);
let act = n3 + n9;
assert_eq!(exp, act);
}
#[test]
fn test_sub() {
MclInitializer::init();
let n9 = MclFr::from(9);
let n3 = MclFr::from(3);
let exp = MclFr::from(6);
let act = n9 - n3;
assert_eq!(exp, act);
}
#[test]
fn test_mul() {
MclInitializer::init();
let n3 = MclFr::from(3);
let n9 = MclFr::from(9);
let exp = MclFr::from(27);
let act = n3 * n9;
assert_eq!(exp, act);
}
#[test]
fn test_inv() {
MclInitializer::init();
let n1 = MclFr::from(1);
let n9 = MclFr::from(9);
let inv9 = n9.inv();
assert_eq!(n9 * inv9, n1);
}
#[test]
fn test_sq() {
MclInitializer::init();
let n3 = MclFr::from(3);
let n9 = MclFr::from(9);
assert_eq!(n3.sq(), n9);
}
#[test]
fn test_neg() {
MclInitializer::init();
let n9 = &MclFr::from(9);
assert_eq!(n9 + -n9, MclFr::zero());
}
#[test]
fn test_inc() {
MclInitializer::init();
let n1 = MclFr::from(1);
let n2 = MclFr::from(2);
let n3 = MclFr::from(3);
let mut n = MclFr::zero();
assert!(n.is_zero());
n.inc();
assert_eq!(n, n1);
n.inc();
assert_eq!(n, n2);
n.inc();
assert_eq!(n, n3);
}
#[test]
fn test_ord() {
MclInitializer::init();
let n2 = &MclFr::from(2);
let n3 = &MclFr::from(3);
assert!((n2 == n2) == true);
assert!((n2 != n2) == false);
assert!((n2 < n2) == false);
assert!((n2 > n2) == false);
assert!((n2 >= n2) == true);
assert!((n2 <= n2) == true);
assert!((n2 == n3) == false);
assert!((n2 != n3) == true);
assert!((n2 < n3) == true);
assert!((n2 > n3) == false);
assert!((n2 >= n3) == false);
assert!((n2 <= n3) == true);
}
#[test]
fn test_hashing() {
MclInitializer::init();
use std::collections::HashMap;
let n2 = &MclFr::from(2);
let n3 = &MclFr::from(3);
let n4 = &MclFr::from(3);
let m = HashMap::<String, MclFr>::from([
("2".to_string(), n2.clone()),
("3".to_string(), n3.clone()),
("4".to_string(), n4.clone()),
]);
assert_eq!(m.get("2").unwrap(), n2);
assert_eq!(m.get("3").unwrap(), n3);
assert_eq!(m.get("4").unwrap(), n4);
}
}

View File

@@ -0,0 +1,186 @@
use mcl_rust::*;
use std::ops::{Add, Mul, Neg, AddAssign};
use num_traits::Zero;
use once_cell::sync::Lazy;
use crate::building_block::mcl::mcl_fr::MclFr;
#[derive(Clone, Debug)]
pub struct MclG1 {
pub v: G1,
}
static GENERATOR: Lazy<MclG1> = Lazy::new(|| {
let serialized_g = "1 3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507 1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569";
let mut v = G1::zero();
G1::set_str(&mut v, serialized_g, 10);
MclG1 { v }
});
impl MclG1 {
pub fn new() -> Self {
let v = G1::zero();
MclG1::from(&v)
}
pub fn g() -> MclG1 {
GENERATOR.clone()
}
pub fn inv(&self) -> Self {
let mut v = G1::zero();
G1::neg(&mut v, &self.v);
MclG1::from(&v)
}
pub fn get_random_point() -> MclG1 {
let mut v = Fr::zero();
v.set_by_csprng();
MclG1::g() * MclFr::from(&v)
}
}
impl Zero for MclG1 {
fn zero() -> MclG1 {
let v = G1::zero();
MclG1::from(&v)
}
fn is_zero(&self) -> bool {
self.v.is_zero()
}
}
impl From<&G1> for MclG1 {
fn from(v: &G1) -> Self {
MclG1 { v: v.clone() }
}
}
macro_rules! impl_add {
($rhs: ty, $target: ty) => {
impl Add<$rhs> for $target {
type Output = MclG1;
fn add(self, rhs: $rhs) -> Self::Output {
let mut v = G1::zero();
G1::add(&mut v, &self.v, &rhs.v);
MclG1::from(&v)
}
}
};
}
impl_add!(MclG1, MclG1);
impl_add!(&MclG1, MclG1);
impl_add!(MclG1, &MclG1);
impl_add!(&MclG1, &MclG1);
macro_rules! impl_mul {
($rhs: ty, $target: ty) => {
impl Mul<$rhs> for $target {
type Output = MclG1;
fn mul(self, rhs: $rhs) -> Self::Output {
let mut v = G1::zero();
G1::mul(&mut v, &self.v, &rhs.v);
MclG1::from(&v)
}
}
};
}
impl_mul!(MclFr, MclG1);
impl_mul!(&MclFr, MclG1);
impl_mul!(MclFr, &MclG1);
impl_mul!(&MclFr, &MclG1);
impl AddAssign<MclG1> for MclG1 {
fn add_assign(&mut self, rhs: MclG1) {
*self = &*self + rhs
}
}
impl PartialEq for MclG1 {
fn eq(&self, rhs: &Self) -> bool {
self.v == rhs.v
}
}
impl Eq for MclG1 {}
macro_rules! impl_neg {
($target: ty) => {
impl Neg for $target {
type Output = MclG1;
fn neg(self) -> Self::Output {
let mut v = G1::zero();
G1::neg(&mut v, &self.v);
MclG1::from(&v)
}
}
}
}
impl_neg!(MclG1);
impl_neg!(&MclG1);
#[cfg(test)]
mod tests {
use super::*;
use crate::building_block::mcl::mcl_initializer::MclInitializer;
#[test]
fn equality() {
MclInitializer::init();
let g = MclG1::g();
let g2 = &g + &g;
assert_eq!(&g, &g);
assert_eq!(&g2, &g2);
assert_ne!(&g, &g2);
}
#[test]
fn add() {
MclInitializer::init();
let g = &MclG1::g();
let g2 = &(g + g);
let g4 = &(g2 + g2);
{
let act = g + g;
let exp = g2;
assert_eq!(&act, exp);
}
{
let act = g2 + g2;
let exp = g4;
assert_eq!(&act, exp);
}
}
#[test]
fn scalar_mul() {
MclInitializer::init();
let g = &MclG1::g();
let n4 = MclFr::from(4);
let act = g * n4;
let exp = g + g + g + g;
assert_eq!(act, exp);
}
#[test]
fn neg() {
MclInitializer::init();
let g = &MclG1::g();
let n4 = &MclFr::from(4);
let g_n4 = g * n4;
let g_n4_neg = (g * n4).neg();
let act = g_n4 + g_n4_neg;
let exp = MclG1::zero();
assert_eq!(act, exp);
}
}

View File

@@ -0,0 +1,199 @@
use mcl_rust::*;
use std::ops::{Add, Mul, Neg, AddAssign};
use num_traits::Zero;
use once_cell::sync::Lazy;
use crate::building_block::mcl::mcl_fr::MclFr;
#[derive(Clone, Debug)]
pub struct MclG2 {
pub v: G2,
}
static GENERATOR: Lazy<MclG2> = Lazy::new(|| {
let serialized_g = "1 352701069587466618187139116011060144890029952792775240219908644239793785735715026873347600343865175952761926303160 3059144344244213709971259814753781636986470325476647558659373206291635324768958432433509563104347017837885763365758 1985150602287291935568054521177171638300868978215655730859378665066344726373823718423869104263333984641494340347905 927553665492332455747201965776037880757740193453592970025027978793976877002675564980949289727957565575433344219582";
let mut v = G2::zero();
G2::set_str(&mut v, serialized_g, 10);
MclG2 { v }
});
impl MclG2 {
pub fn new() -> Self {
let v = G2::zero();
MclG2::from(&v)
}
pub fn g() -> MclG2 {
GENERATOR.clone()
}
pub fn inv(&self) -> Self {
let mut v = G2::zero();
G2::neg(&mut v, &self.v);
MclG2::from(&v)
}
pub fn get_random_point() -> MclG2 {
let mut v = Fr::zero();
v.set_by_csprng();
MclG2::g() * MclFr::from(&v)
}
pub fn hash_and_map(buf: &Vec<u8>) -> MclG2 {
let mut v = G2::zero();
G2::set_hash_of(&mut v, buf);
MclG2::from(&v)
}
}
impl Zero for MclG2 {
fn zero() -> MclG2 {
let v = G2::zero();
MclG2::from(&v)
}
fn is_zero(&self) -> bool {
self.v.is_zero()
}
}
impl From<&G2> for MclG2 {
fn from(v: &G2) -> Self {
MclG2 { v: v.clone() }
}
}
macro_rules! impl_add {
($rhs: ty, $target: ty) => {
impl Add<$rhs> for $target {
type Output = MclG2;
fn add(self, rhs: $rhs) -> Self::Output {
let mut v = G2::zero();
G2::add(&mut v, &self.v, &rhs.v);
MclG2::from(&v)
}
}
};
}
impl_add!(MclG2, MclG2);
impl_add!(&MclG2, MclG2);
impl_add!(MclG2, &MclG2);
impl_add!(&MclG2, &MclG2);
macro_rules! impl_mul {
($rhs: ty, $target: ty) => {
impl Mul<$rhs> for $target {
type Output = MclG2;
fn mul(self, rhs: $rhs) -> Self::Output {
let mut v = G2::zero();
G2::mul(&mut v, &self.v, &rhs.v);
MclG2::from(&v)
}
}
};
}
impl_mul!(MclFr, MclG2);
impl_mul!(&MclFr, MclG2);
impl_mul!(MclFr, &MclG2);
impl_mul!(&MclFr, &MclG2);
impl AddAssign<MclG2> for MclG2 {
fn add_assign(&mut self, rhs: MclG2) {
*self = &*self + rhs
}
}
impl PartialEq for MclG2 {
fn eq(&self, rhs: &Self) -> bool {
self.v == rhs.v
}
}
impl Eq for MclG2 {}
macro_rules! impl_neg {
($target: ty) => {
impl Neg for $target {
type Output = MclG2;
fn neg(self) -> Self::Output {
let mut v = G2::zero();
G2::neg(&mut v, &self.v);
MclG2::from(&v)
}
}
}
}
impl_neg!(MclG2);
impl_neg!(&MclG2);
#[cfg(test)]
mod tests {
use super::*;
use crate::building_block::mcl::mcl_initializer::MclInitializer;
#[test]
fn equality() {
MclInitializer::init();
let g = MclG2::g();
println!("g {:?}", &g);
let g2 = &g + &g;
assert_eq!(&g, &g);
assert_eq!(&g2, &g2);
assert_ne!(&g, &g2);
}
#[test]
fn add() {
MclInitializer::init();
let g = &MclG2::g();
println!("g is {:?}", &g);
let g2 = &(g + g);
let g4 = &(g2 + g2);
{
let act = g + g;
let exp = g2;
assert_eq!(&act, exp);
}
{
let act = g + g;
let exp = g * MclFr::from(2);
assert_eq!(&act, &exp);
}
{
let act = g2 + g2;
let exp = g4;
assert_eq!(&act, exp);
}
}
#[test]
fn scalar_mul() {
MclInitializer::init();
let g = &MclG2::g();
let n4 = MclFr::from(4);
let act = g * n4;
let exp = g + g + g + g;
assert_eq!(act, exp);
}
#[test]
fn neg() {
MclInitializer::init();
let g = &MclG2::g();
let n4 = &MclFr::from(4);
let g_n4 = g * n4;
let g_n4_neg = (g * n4).neg();
let act = g_n4 + g_n4_neg;
let exp = MclG2::zero();
assert_eq!(act, exp);
}
}

View File

@@ -0,0 +1,200 @@
use mcl_rust::*;
use std::{
convert::From,
fmt,
ops::{Add,
Sub,
Mul,
Neg,
},
};
use num_traits::Zero;
#[derive(Debug, Clone)]
pub struct MclGT {
pub v: GT,
}
impl MclGT {
pub fn new() -> Self {
let v = GT::zero();
MclGT::from(&v)
}
pub fn inv(&self) -> Self {
let mut v = GT::zero();
GT::inv(&mut v, &self.v);
MclGT::from(&v)
}
pub fn sq(&self) -> Self {
let mut v = GT::zero();
GT::sqr(&mut v, &self.v);
MclGT::from(&v)
}
}
impl Zero for MclGT {
fn is_zero(&self) -> bool {
self.v.is_zero()
}
fn zero() -> Self {
MclGT::from(&GT::zero())
}
}
impl From<i32> for MclGT {
fn from(value: i32) -> Self {
let v = GT::from_int(value);
MclGT { v }
}
}
impl From<&GT> for MclGT {
fn from(v: &GT) -> Self {
MclGT { v: v.clone() }
}
}
impl PartialEq for MclGT {
fn eq(&self, other: &Self) -> bool {
self.v == other.v
}
}
impl Eq for MclGT {}
impl fmt::Display for MclGT {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.v.get_str(16))
}
}
macro_rules! impl_neg {
($target: ty) => {
impl Neg for $target {
type Output = MclGT;
fn neg(self) -> Self::Output {
let mut v = GT::zero();
GT::neg(&mut v, &self.v);
MclGT::from(&v)
}
}
}
}
impl_neg!(MclGT);
impl_neg!(&MclGT);
macro_rules! impl_add {
($rhs: ty, $target: ty) => {
impl Add<$rhs> for $target {
type Output = MclGT;
fn add(self, rhs: $rhs) -> Self::Output {
let mut v = GT::zero();
GT::add(&mut v, &self.v, &rhs.v);
MclGT::from(&v)
}
}
};
}
impl_add!(MclGT, MclGT);
impl_add!(&MclGT, MclGT);
impl_add!(MclGT, &MclGT);
impl_add!(&MclGT, &MclGT);
macro_rules! impl_sub {
($rhs: ty, $target: ty) => {
impl Sub<$rhs> for $target {
type Output = MclGT;
fn sub(self, rhs: $rhs) -> Self::Output {
let mut v = GT::zero();
GT::sub(&mut v, &self.v, &rhs.v);
MclGT::from(&v)
}
}
};
}
impl_sub!(MclGT, MclGT);
impl_sub!(&MclGT, MclGT);
impl_sub!(MclGT, &MclGT);
impl_sub!(&MclGT, &MclGT);
macro_rules! impl_mul {
($rhs: ty, $target: ty) => {
impl Mul<$rhs> for $target {
type Output = MclGT;
fn mul(self, rhs: $rhs) -> Self::Output {
let mut v = GT::zero();
GT::mul(&mut v, &self.v, &rhs.v);
MclGT { v }
}
}
};
}
impl_mul!(MclGT, MclGT);
impl_mul!(&MclGT, MclGT);
impl_mul!(MclGT, &MclGT);
impl_mul!(&MclGT, &MclGT);
#[cfg(test)]
mod tests {
use super::*;
use crate::building_block::mcl::mcl_initializer::MclInitializer;
#[test]
fn test_add() {
MclInitializer::init();
let n3 = MclGT::from(3i32);
let n9 = MclGT::from(9i32);
let exp = MclGT::from(12i32);
let act = n3 + n9;
assert_eq!(exp, act);
}
#[test]
fn test_sub() {
MclInitializer::init();
let n9 = MclGT::from(9i32);
let n3 = MclGT::from(3i32);
let exp = MclGT::from(6i32);
let act = n9 - n3;
assert_eq!(exp, act);
}
#[test]
fn test_mul() {
MclInitializer::init();
let n3 = MclGT::from(3i32);
let n9 = MclGT::from(9i32);
let exp = MclGT::from(27i32);
let act = n3 * n9;
assert_eq!(exp, act);
}
// #[test]
// fn test_inv() {
// MclInitializer::init();
//
// let n1 = MclGT::from(1i32);
// let n9 = MclGT::from(9i32);
// let inv9 = n9.inv();
//
// assert_eq!(n9 * inv9, n1);
// }
#[test]
fn test_neg() {
MclInitializer::init();
let n9 = &MclGT::from(9i32);
assert_eq!(n9 + -n9, MclGT::zero());
}
}

View File

@@ -0,0 +1,26 @@
use mcl_rust::*;
use std::sync::Once;
static INIT: Once = Once::new();
pub struct MclInitializer;
impl MclInitializer {
pub fn init() {
INIT.call_once(|| {
if !init(CurveType::BLS12_381) {
panic!("Failed to initialize mcl");
}
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn init() {
MclInitializer::init();
}
}

View File

@@ -0,0 +1,824 @@
use num_traits::Zero;
use crate::building_block::mcl::{
mcl_fr::MclFr,
mcl_sparse_vec::MclSparseVec,
polynomial::Polynomial,
};
use std::{
collections::HashMap,
convert::From,
fmt,
ops::Mul,
};
pub struct MclSparseMatrix {
pub width: MclFr,
pub height: MclFr,
rows: HashMap<MclFr, MclSparseVec>,
zero: MclFr,
}
impl MclSparseMatrix {
pub fn new(width: &MclFr, height: &MclFr) -> Self {
let zero = MclFr::zero();
let rows = HashMap::new();
Self {
width: width.clone(),
height: height.clone(),
rows,
zero,
}
}
pub fn pretty_print(&self) -> String {
let mut s = String::new();
let empty_row = &MclSparseVec::new(&self.width);
let mut y = MclFr::zero();
while y < self.height {
match self.rows.get(&y) {
Some(row) => {
s = format!("{}{}\n", s, row.pretty_print());
},
None => {
s = format!("{}{}\n", s, empty_row.pretty_print());
}
}
y.inc();
}
s
}
pub fn multiply_column(&self, col: &MclSparseVec) -> Self {
if col.size != self.height {
panic!("column size is expected to be {:?}, but got {:?}",
self.height, col.size)
}
let mut m = MclSparseMatrix::new(&self.width, &self.height);
let mut y = MclFr::from(0);
while y < col.size {
let mut x = MclFr::from(0);
let multiplier = col.get(&y);
while x < self.width {
let v = self.get(&x, &y) * multiplier;
m.set(&x, &y, &v);
x.inc();
}
y.inc();
}
m
}
pub fn flatten_rows(&self) -> MclSparseVec {
let mut vec = MclSparseVec::new(&self.width);
let mut y = MclFr::from(0);
while y < self.height {
println!("y={:?}", &y);
let mut x = MclFr::from(0);
while x < self.width {
println!("x={:?}", &x);
let v = vec.get(&x) + self.get(&x, &y);
vec.set(&x, &v);
x.inc();
}
y.inc();
}
vec
}
pub fn set(&mut self, x: &MclFr, y: &MclFr, v: &MclFr) -> () {
if x >= &self.width || y >= &self.height {
panic!("For {:?} x {:?} matrix, ({:?}, {:?}) is out of range",
self.width, self.height, x, y);
}
if v.is_zero() { // don't set if zero
return;
}
if !self.rows.contains_key(&y) {
let vec = MclSparseVec::new(&self.width);
self.rows.insert(y.clone(), vec);
}
self.rows.get_mut(&y).unwrap().set(&x, &v);
}
pub fn get(&self, x: &MclFr, y: &MclFr) -> &MclFr {
if x >= &self.width || y >= &self.height {
panic!("For {:?} x {:?} matrix, ({:?}, {:?}) is out of range",
self.width, self.height, x, y);
}
if !self.rows.contains_key(&y) {
&self.zero
} else {
self.rows.get(&y).unwrap().get(&x)
}
}
pub fn get_row(&self, y: &MclFr) -> MclSparseVec {
assert!(y < &self.height);
let mut row = MclSparseVec::new(&self.width);
if !self.rows.contains_key(y) {
return row;
}
let src_row = self.rows.get(y).unwrap();
for x in src_row.indices() {
let v = src_row.get(&x);
if !v.is_zero() {
row.set(&x, v);
}
}
row
}
pub fn get_column(&self, x: &MclFr) -> MclSparseVec {
assert!(x < &self.width);
let mut col = MclSparseVec::new(&self.height);
for y in self.rows.keys() {
let src_row = self.rows.get(&y).unwrap();
let v = src_row.get(x);
if !v.is_zero() {
col.set(y, v);
}
}
col
}
pub fn transpose(&self) -> MclSparseMatrix {
let mut m = MclSparseMatrix::new(&self.height, &self.width);
for y in self.rows.keys() {
let src_row = self.rows.get(&y).unwrap();
for x in src_row.indices() {
let v = src_row.get(&x);
if !v.is_zero() {
m.set(y, &x, v);
}
}
}
m
}
// remove empty rows
pub fn normalize(&self) -> MclSparseMatrix {
let mut m = MclSparseMatrix::new(&self.width, &self.height);
for row_key in self.rows.keys() {
let row = self.rows.get(row_key).unwrap();
if !row.is_empty() {
m.rows.insert(row_key.clone(), row.clone());
}
}
m
}
pub fn row_transform(&self, transform: Box<dyn Fn(&MclSparseVec) -> MclSparseVec>) -> MclSparseMatrix {
let mut m = MclSparseMatrix::new(&self.width, &self.height);
let mut y = MclFr::zero();
while y < self.height {
let in_row = self.get_row(&y);
let out_row = transform(&in_row);
let mut x = MclFr::zero();
while x < self.width {
let v = out_row.get(&x);
m.set(&x, &y, v);
x.inc();
}
y.inc();
}
m
}
}
impl PartialEq for MclSparseMatrix {
fn eq(&self, other: &MclSparseMatrix) -> bool {
if self.width != other.width || self.height != other.height {
return false;
}
for key in self.rows.keys() {
if !other.rows.contains_key(key) {
return false;
}
let self_row = &self.rows[key];
let other_row = &other.rows[key];
if self_row != other_row {
return false;
}
}
for key in other.rows.keys() {
if !self.rows.contains_key(key) {
return false;
}
let self_row = &self.rows[key];
let other_row = &other.rows[key];
if self_row != other_row {
return false;
}
}
true
}
}
impl Into<Vec<Polynomial>> for MclSparseMatrix {
fn into(self) -> Vec<Polynomial> {
let mut vec = vec![];
let mut i = MclFr::zero();
while &i < &self.height {
let p = Polynomial::from(&self.get_row(&i));
vec.push(p);
i.inc();
}
vec
}
}
// coverts rows of vectors to a matrix
impl From<&Vec<MclSparseVec>> for MclSparseMatrix {
fn from(rows: &Vec<MclSparseVec>) -> Self {
assert!(rows.len() != 0, "Cannot build matrix from empty vector");
let width = &rows[0].size;
let height = rows.len();
for i in 1..height {
if width != &rows[i].size {
panic!("different row sizes found; size is {:?} at 0, but {:?} at {}",
width, &rows[i].size, i)
}
}
let mut m = MclSparseMatrix::new(width, &MclFr::from(height));
for (y, row) in rows.iter().enumerate() {
for x in row.indices() {
let v = row.get(&x);
if !v.is_zero() {
m.set(&x, &MclFr::from(y), v);
}
}
}
m.normalize()
}
}
impl Mul<&MclSparseMatrix> for &MclSparseMatrix {
type Output = MclSparseMatrix;
fn mul(self, rhs: &MclSparseMatrix) -> Self::Output {
if self.width != rhs.height {
panic!("Can only multiply matrix with height {:?}, but the rhs height is {:?}",
self.width, rhs.height);
}
let mut res = MclSparseMatrix::new(&rhs.width, &self.height);
let mut y = MclFr::zero();
while y < self.height {
let mut x = MclFr::zero();
while x < rhs.width {
let lhs = &self.get_row(&y);
let rhs = &rhs.get_column(&x);
let v = (lhs * rhs).sum();
if !v.is_zero() {
res.set(&x, &y, &v);
}
x.inc();
}
y.inc();
}
res.normalize()
}
}
impl fmt::Debug for MclSparseMatrix {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", &self.pretty_print())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::building_block::mcl::mcl_initializer::MclInitializer;
#[test]
fn test_size() {
MclInitializer::init();
let m = MclSparseMatrix::new(&MclFr::from(2), &MclFr::from(3));
assert_eq!(m.width, MclFr::from(2));
assert_eq!(m.height, MclFr::from(3));
}
#[test]
#[should_panic]
fn test_get_x_out_of_range() {
MclInitializer::init();
std::panic::set_hook(Box::new(|_| {})); // suppress stack trace
let m = MclSparseMatrix::new(&MclFr::from(2), &MclFr::from(3));
m.get(&MclFr::from(2), &MclFr::from(1));
}
#[test]
#[should_panic]
fn test_get_y_out_of_range() {
MclInitializer::init();
std::panic::set_hook(Box::new(|_| {})); // suppress stack trace
let m = MclSparseMatrix::new(&MclFr::from(2), &MclFr::from(3));
m.get(&MclFr::from(1), &MclFr::from(3));
}
#[test]
#[should_panic]
fn test_set_x_out_of_range() {
MclInitializer::init();
std::panic::set_hook(Box::new(|_| {})); // suppress stack trace
let mut m = MclSparseMatrix::new(&MclFr::from(2), &MclFr::from(3));
m.set(&MclFr::from(2), &MclFr::from(1), &MclFr::from(12));
}
#[test]
#[should_panic]
fn test_set_y_out_of_range() {
MclInitializer::init();
std::panic::set_hook(Box::new(|_| {})); // suppress stack trace
let mut m = MclSparseMatrix::new(&MclFr::from(2), &MclFr::from(3));
m.set(&MclFr::from(1), &MclFr::from(3), &MclFr::from(12));
}
#[test]
fn test_get_empty() {
MclInitializer::init();
let zero = &MclFr::from(0);
let m = MclSparseMatrix::new(&MclFr::from(2), &MclFr::from(3));
for x in 0..2 {
for y in 0..3 {
assert_eq!(m.get(&MclFr::from(x), &MclFr::from(y)), zero);
}
}
}
#[test]
fn test_mul() {
MclInitializer::init();
let zero = &MclFr::from(0);
let m = MclSparseMatrix::new(&MclFr::from(2), &MclFr::from(3));
for x in 0..2 {
for y in 0..3 {
assert_eq!(m.get(&MclFr::from(x), &MclFr::from(y)), zero);
}
}
}
#[test]
fn test_get_existing_and_non_existing_cells() {
MclInitializer::init();
let zero = &MclFr::from(0);
let one = &MclFr::from(1);
let two = &MclFr::from(2);
let three = &MclFr::from(3);
let eight = &MclFr::from(8);
let nine = &MclFr::from(9);
let mut m = MclSparseMatrix::new(two, three);
m.set(zero, two, nine);
m.set(one, one, eight);
assert_eq!(m.get(zero, two), nine);
assert_eq!(m.get(one, one), eight);
assert_eq!(m.get(one, two), zero);
}
#[test]
#[should_panic]
fn test_from_empty_vec() {
MclInitializer::init();
std::panic::set_hook(Box::new(|_| {})); // suppress stack trace
let _ = MclSparseMatrix::from(&vec![]);
}
// |1 0|
// |0 1|
fn gen_test_2x2_identity_matrix() -> MclSparseMatrix {
MclInitializer::init();
let zero = &MclFr::from(0);
let one = &MclFr::from(1);
let mut v1 = MclSparseVec::new(&MclFr::from(2));
let mut v2 = MclSparseVec::new(&MclFr::from(2));
v1.set(zero, one);
v2.set(one, one);
let vecs = vec![v1, v2];
MclSparseMatrix::from(&vecs)
}
// |1 2|
// |3 4|
fn gen_test_2x2_matrix() -> MclSparseMatrix {
MclInitializer::init();
let zero = &MclFr::from(0);
let one = &MclFr::from(1);
let two = &MclFr::from(2);
let three = &MclFr::from(3);
let four = &MclFr::from(4);
let mut v1 = MclSparseVec::new(&MclFr::from(2));
let mut v2 = MclSparseVec::new(&MclFr::from(2));
v1.set(zero, one);
v1.set(one, two);
v2.set(zero, three);
v2.set(one, four);
let vecs = vec![v1, v2];
MclSparseMatrix::from(&vecs)
}
// |1 0|
// |0 2|
// |3 0|
fn gen_test_2x3_matrix() -> MclSparseMatrix {
MclInitializer::init();
let zero = &MclFr::from(0);
let one = &MclFr::from(1);
let two = &MclFr::from(2);
let three = &MclFr::from(3);
let mut v1 = MclSparseVec::new(&MclFr::from(2));
let mut v2 = MclSparseVec::new(&MclFr::from(2));
let mut v3 = MclSparseVec::new(&MclFr::from(2));
v1.set(zero, one);
v2.set(one, two);
v3.set(zero, three);
let vecs = vec![v1, v2, v3];
MclSparseMatrix::from(&vecs)
}
// |1 2 3|
// |3 2 1|
fn gen_test_3x2_matrix() -> MclSparseMatrix {
MclInitializer::init();
let zero = &MclFr::from(0);
let one = &MclFr::from(1);
let two = &MclFr::from(2);
let three = &MclFr::from(3);
let mut v1 = MclSparseVec::new(&MclFr::from(3));
let mut v2 = MclSparseVec::new(&MclFr::from(3));
v1.set(zero, one);
v1.set(one, two);
v1.set(two, three);
v2.set(zero, three);
v2.set(one, two);
v2.set(two, one);
let vecs = vec![v1, v2];
MclSparseMatrix::from(&vecs)
}
// |1 2|
// |0 0|
fn gen_test_2x2_redundant_matrix(use_empty_row: bool) -> MclSparseMatrix {
MclInitializer::init();
let zero = &MclFr::from(0);
let one = &MclFr::from(1);
let two = &MclFr::from(2);
let mut v1 = MclSparseVec::new(&MclFr::from(2));
v1.set(zero, one);
v1.set(one, two);
if use_empty_row {
let mut rows = HashMap::<MclFr, MclSparseVec>::new();
rows.insert(zero.clone(), v1.clone());
rows.insert(one.clone(), MclSparseVec::new(&MclFr::from(2)));
MclSparseMatrix {
width: two.clone(),
height: two.clone(),
rows,
zero: zero.clone(),
}
} else {
let mut rows = HashMap::<MclFr, MclSparseVec>::new();
rows.insert(zero.clone(), v1.clone());
MclSparseMatrix {
width: two.clone(),
height: two.clone(),
rows,
zero: zero.clone(),
}
}
}
#[test]
fn test_eq() {
MclInitializer::init();
{
let m1 = gen_test_2x2_identity_matrix();
let m2 = gen_test_2x2_matrix();
let m3 = gen_test_3x2_matrix();
assert!(&m1 == &m1);
assert!(&m2 == &m2);
assert!(&m3 == &m3);
}
{
let m1 = gen_test_3x2_matrix();
let m2 = gen_test_3x2_matrix();
assert!(&m1 == &m2);
assert!(&m2 == &m1);
}
}
#[test]
fn test_eq_with_redundant_matrix() {
MclInitializer::init();
let m1 = gen_test_2x2_redundant_matrix(true);
let m2 = gen_test_2x2_redundant_matrix(false);
assert!(&m1 != &m2);
assert!(&m2 != &m1);
}
#[test]
fn test_non_eq() {
MclInitializer::init();
let m1 = gen_test_2x2_identity_matrix();
let m2 = gen_test_2x2_matrix();
let m3 = gen_test_3x2_matrix();
assert!(&m1 != &m2);
assert!(&m2 != &m1);
assert!(&m1 != &m3);
assert!(&m3 != &m1);
assert!(&m2 != &m1);
assert!(&m1 != &m2);
assert!(&m2 != &m3);
assert!(&m3 != &m2);
assert!(&m3 != &m1);
assert!(&m1 != &m3);
assert!(&m3 != &m2);
assert!(&m2 != &m3);
}
#[test]
fn test_normalize() {
MclInitializer::init();
let m1 = gen_test_2x2_redundant_matrix(true);
let m2 = gen_test_2x2_redundant_matrix(false);
let m1 = m1.normalize();
assert!(&m1 == &m2);
assert!(&m2 == &m1);
}
#[test]
fn test_from_sparse_vecs() {
MclInitializer::init();
let m = gen_test_2x3_matrix();
let zero = &MclFr::from(0);
let one = &MclFr::from(1);
let two = &MclFr::from(2);
let three = &MclFr::from(3);
assert_eq!(&m.width, two);
assert_eq!(&m.height, three);
assert_eq!(m.get(zero, zero), one);
assert_eq!(m.get(one, one), two);
assert_eq!(m.get(zero, two), three);
}
#[test]
#[should_panic]
fn test_get_row_x_out_of_range() {
MclInitializer::init();
std::panic::set_hook(Box::new(|_| {})); // suppress stack trace
let m = gen_test_2x3_matrix();
let _ = m.get_row(&m.height);
}
#[test]
fn test_get_row_x_within_range() {
MclInitializer::init();
let m = gen_test_2x3_matrix();
let zero = &MclFr::from(0);
let one = &MclFr::from(1);
let two = &MclFr::from(2);
let three = &MclFr::from(3);
let r0 = m.get_row(zero);
assert_eq!(r0.get(zero), one);
assert_eq!(r0.get(one), zero);
let r1 = m.get_row(one);
assert_eq!(r1.get(zero), zero);
assert_eq!(r1.get(one), two);
let r2 = m.get_row(two);
assert_eq!(r2.get(zero), three);
assert_eq!(r2.get(one), zero);
}
#[test]
#[should_panic]
fn test_get_column_out_of_range() {
MclInitializer::init();
std::panic::set_hook(Box::new(|_| {})); // suppress stack trace
let m = gen_test_2x3_matrix();
let _ = m.get_column(&m.width);
}
#[test]
fn test_get_column_within_range() {
MclInitializer::init();
let m = gen_test_2x3_matrix();
let zero = &MclFr::from(0);
let one = &MclFr::from(1);
let two = &MclFr::from(2);
let three = &MclFr::from(3);
let c0 = m.get_column(zero);
assert_eq!(c0.get(zero), one);
assert_eq!(c0.get(one), zero);
assert_eq!(c0.get(two), three);
let c1 = m.get_column(one);
assert_eq!(c1.get(zero), zero);
assert_eq!(c1.get(one), two);
assert_eq!(c1.get(two), zero);
}
#[test]
#[should_panic]
fn test_mul_incompatible_matrices() {
MclInitializer::init();
std::panic::set_hook(Box::new(|_| {})); // suppress stack trace
let m = gen_test_2x3_matrix();
let _ = &m * &m;
}
#[test]
fn test_mul_different_sizes() {
MclInitializer::init();
let m1 = gen_test_3x2_matrix();
let m2 = gen_test_2x3_matrix();
let m3 = &m1 * &m2;
let zero = &MclFr::from(0);
let one = &MclFr::from(1);
let two = &MclFr::from(2);
let six = &MclFr::from(6);
let four = &MclFr::from(4);
let ten = &MclFr::from(10);
assert_eq!(&m3.width, two);
assert_eq!(&m3.height, two);
assert_eq!(m3.get(zero, zero), ten);
assert_eq!(m3.get(one, zero), four);
assert_eq!(m3.get(zero, one), six);
assert_eq!(m3.get(one, one), four);
}
#[test]
fn test_mul_identity() {
MclInitializer::init();
let m1 = gen_test_2x2_matrix();
let m2 = gen_test_2x2_identity_matrix();
let m3 = &m1 * &m2;
let two = &MclFr::from(2);
assert_eq!(&m3.width, two);
assert_eq!(&m3.height, two);
assert!(&m1 == &m3);
}
#[test]
fn test_transpose() {
MclInitializer::init();
let m = &gen_test_3x2_matrix();
let mt = &m.transpose();
assert_eq!(m.height, mt.width);
assert_eq!(m.width, mt.height);
let mut x = MclFr::zero();
while x < m.width {
let mut y = MclFr::zero();
while y < m.height {
let m_v = m.get(&x, &y);
let mt_v = mt.get(&y, &x);
assert_eq!(m_v, mt_v);
y.inc();
}
x.inc();
}
}
#[test]
fn test_row_transform() {
MclInitializer::init();
// |1 2 3|
// |3 2 1|
let m = gen_test_3x2_matrix();
let transform = move |in_vec: &MclSparseVec| {
let one = MclFr::from(1);
let mut out_vec = MclSparseVec::new(&in_vec.size);
let mut i = MclFr::from(0);
while i < in_vec.size {
let v = in_vec.get(&i) + &one;
out_vec.set(&i, &v);
i.inc();
}
out_vec
};
let m = m.row_transform(Box::new(transform));
{
let zero = &MclFr::from(0);
let one = &MclFr::from(1);
let two = &MclFr::from(2);
let three = &MclFr::from(3);
let four = &MclFr::from(4);
// |2 3 4|
// |4 3 2|
assert_eq!(m.get(zero, zero), two);
assert_eq!(m.get(one, zero), three);
assert_eq!(m.get(two, zero), four);
assert_eq!(m.get(zero, one), four);
assert_eq!(m.get(one, one), three);
assert_eq!(m.get(two, one), two);
}
}
#[test]
fn test_multiply_column() {
MclInitializer::init();
// |1 2 3|
// |3 2 1|
let m = gen_test_3x2_matrix();
let zero = &MclFr::from(0);
let one = &MclFr::from(1);
let two = &MclFr::from(2);
let three = &MclFr::from(3);
let four = &MclFr::from(4);
let six = &MclFr::from(6);
let nine = &MclFr::from(9);
// |2|
// |3|
let mut col = MclSparseVec::new(&m.height);
col.set(zero, two);
col.set(one, three);
let m = m.multiply_column(&col);
// |2 4 6|
// |9 6 3|
assert_eq!(m.get(zero, zero), two);
assert_eq!(m.get(one, zero), four);
assert_eq!(m.get(two, zero), six);
assert_eq!(m.get(zero, one), nine);
assert_eq!(m.get(one, one), six);
assert_eq!(m.get(two, one), three);
}
#[test]
fn test_flatten_rows() {
MclInitializer::init();
// |1 2 3|
// |3 2 1|
let m = gen_test_3x2_matrix();
let zero = &MclFr::from(0);
let one = &MclFr::from(1);
let two = &MclFr::from(2);
let four = &MclFr::from(4);
let row = m.flatten_rows();
// |4 2 4|
assert_eq!(row.get(zero), four);
assert_eq!(row.get(one), four);
assert_eq!(row.get(two), four);
}
}

View File

@@ -0,0 +1,559 @@
use std::{
collections::HashMap,
convert::From,
ops::Mul,
};
use crate::building_block::mcl::mcl_fr::MclFr;
use num_traits::Zero;
use core::ops::{Index, IndexMut};
#[derive(Clone)]
pub struct MclSparseVec {
pub size: MclFr,
zero: MclFr,
elems: HashMap<MclFr, MclFr>, // HashMap<index, value>
}
impl std::fmt::Debug for MclSparseVec {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
let s = self.pretty_print();
write!(fmt, "{}", s)
}
}
pub struct MclSparseVecIterator<'a> {
sv: &'a MclSparseVec,
i: MclFr,
}
impl<'a> Iterator for MclSparseVecIterator<'a> {
type Item = MclFr;
fn next(&mut self) -> Option<Self::Item> {
if &self.sv.size == &self.i {
None
} else {
let elem = self.sv[&self.i].clone();
self.i.inc();
Some(elem)
}
}
}
impl MclSparseVec {
pub fn new(size: &MclFr) -> Self {
if size.is_zero() {
panic!("Size must be greater than 0");
}
MclSparseVec {
zero: MclFr::zero(),
size: size.clone(),
elems: HashMap::<MclFr, MclFr>::new(),
}
}
pub fn iter(&self) -> MclSparseVecIterator {
MclSparseVecIterator { sv: self, i: MclFr::zero() }
}
pub fn set(&mut self, index: &MclFr, n: &MclFr) {
if index >= &self.size {
panic!("Index {:?} is out of range. The size of vector is {:?}", index, self.size);
}
if !n.is_zero() {
self.elems.insert(index.clone(), n.clone());
}
}
pub fn get(&self, index: &MclFr) -> &MclFr {
if index >= &self.size {
panic!("Index {:?} is out of range. The size of vector is {:?}", index, self.size);
}
if self.elems.contains_key(index) {
self.elems.get(index).unwrap()
} else {
&self.zero
}
}
pub fn indices(&self) -> Vec<MclFr> {
let mut vec = vec![];
for x in self.elems.keys() {
vec.push(x.clone());
}
vec
}
// TODO clean up
pub fn sum(&self) -> MclFr {
let mut values = vec![];
for value in self.elems.values() {
values.push(value);
}
let mut sum = values[0].clone();
for value in &values[1..] {
sum = sum + *value;
}
sum
}
// empty if containing only zeros
pub fn is_empty(&self) -> bool {
for value in self.elems.values() {
if !value.is_zero() {
return false;
}
}
true
}
pub fn pretty_print(&self) -> String {
let mut s = "[".to_string();
let mut i = MclFr::zero();
let one = &MclFr::from(1);
while &i < &self.size {
s += &format!("{:?}", self.get(&i));
if &i < &(&self.size - one) {
s += ",";
}
i.inc();
}
s += "]";
s
}
// returns a vector of range [from..to)
pub fn slice(&self, from: &MclFr, to: &MclFr) -> Self {
let size = to - from;
let mut new_sv = MclSparseVec::new(&size);
let mut i = from.clone();
while &i < &to {
new_sv.set(&(&i - from), &self[&i]);
i.inc();
}
new_sv
}
pub fn concat(&self, other: &MclSparseVec) -> MclSparseVec {
let size = &self.size + &other.size;
let mut sv = MclSparseVec::new(&size);
let mut i = MclFr::zero();
// copy self to new sv
{
let mut j = MclFr::zero();
while &j < &self.size {
sv[&i] = self[&j].clone();
j.inc();
i.inc();
}
}
// copy other to new sv
{
let mut j = MclFr::zero();
while &j < &other.size {
sv[&i] = other[&j].clone();
j.inc();
i.inc();
}
}
sv
}
}
impl PartialEq for MclSparseVec {
fn eq(&self, other: &MclSparseVec) -> bool {
if self.size != other.size { return false; }
for index in self.elems.keys() {
let other_elem = other.get(index);
let this_elem = self.get(index);
if this_elem != other_elem { return false; }
}
for index in other.elems.keys() {
let other_elem = other.get(index);
let this_elem = self.get(index);
if this_elem != other_elem { return false; }
}
true
}
}
impl Index<&MclFr> for MclSparseVec {
type Output = MclFr;
fn index(&self, index: &MclFr) -> &Self::Output {
&self.get(index)
}
}
impl IndexMut<&MclFr> for MclSparseVec {
fn index_mut(&mut self, index: &MclFr) -> &mut Self::Output {
if !self.elems.contains_key(index) {
self.elems.insert(index.clone(), MclFr::from(0));
}
self.elems.get_mut(index).unwrap()
}
}
impl From<&Vec<MclFr>> for MclSparseVec {
fn from(elems: &Vec<MclFr>) -> Self {
assert!(elems.len() != 0, "Cannot build vector from empty element list");
let size = MclFr::from(elems.len());
let mut vec = MclSparseVec::new(&size);
for (i, v) in elems.iter().enumerate() {
if !v.is_zero() {
vec.set(&MclFr::from(i), v);
}
}
vec
}
}
// returns Hadamard product
impl Mul<&MclSparseVec> for &MclSparseVec {
type Output = MclSparseVec;
fn mul(self, rhs: &MclSparseVec) -> Self::Output {
if self.size != rhs.size {
panic!("Expected size of rhs to be {:?}, but got {:?}", self.size, rhs.size);
}
let mut ret = MclSparseVec::new(&self.size);
for index in self.elems.keys() {
let l = self.get(index);
let r = rhs.get(index);
if !l.is_zero() && !r.is_zero() {
ret.set(index, &(l * r));
}
}
ret
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::building_block::mcl::mcl_initializer::MclInitializer;
#[test]
#[should_panic]
fn test_from_empty_list() {
MclInitializer::init();
std::panic::set_hook(Box::new(|_| {})); // suppress stack trace
let _ = MclSparseVec::from(&vec![]);
}
#[test]
fn test_slice() {
MclInitializer::init();
let zero = &MclFr::zero();
let one = &MclFr::from(1);
let two = &MclFr::from(2);
let three = &MclFr::from(3);
let elems = vec![
zero.clone(),
one.clone(),
two.clone(),
three.clone(),
];
let sv = MclSparseVec::from(&elems);
let sv2 = sv.slice(one, three);
assert_eq!(&sv2.size, two);
assert_eq!(&sv2[zero], one);
assert_eq!(&sv2[one], two);
}
#[test]
fn test_from_non_empty_list() {
MclInitializer::init();
let zero = &MclFr::zero();
let one = &MclFr::from(1);
let two = &MclFr::from(2);
let elems = vec![one.clone(), two.clone()];
let vec = MclSparseVec::from(&elems);
assert_eq!(&vec.size, two);
assert_eq!(&vec[zero], one);
assert_eq!(&vec[one], two);
}
#[test]
fn test_from_non_empty_zero_only_list() {
MclInitializer::init();
let zero = &MclFr::zero();
let two = &MclFr::from(2);
let elems = vec![zero.clone(), zero.clone()];
let vec = MclSparseVec::from(&elems);
assert_eq!(&vec.size, two);
assert_eq!(vec.elems.len(), 0);
}
#[test]
#[should_panic]
fn test_new_empty_vec() {
MclInitializer::init();
std::panic::set_hook(Box::new(|_| {})); // suppress stack trace
MclSparseVec::new(&MclFr::zero());
}
#[test]
#[should_panic]
fn test_bad_set() {
MclInitializer::init();
std::panic::set_hook(Box::new(|_| {})); // suppress stack trace
let mut vec = MclSparseVec::new(&MclFr::from(3));
assert_eq!(vec.elems.len(), 0);
vec.set(&MclFr::from(3), &MclFr::from(2));
}
#[test]
fn test_good_set() {
MclInitializer::init();
let mut vec = MclSparseVec::new(&MclFr::from(3));
assert_eq!(vec.elems.len(), 0);
let two = &MclFr::from(2);
vec.set(&MclFr::from(2), two);
assert_eq!(vec.elems.len(), 1);
assert_eq!(vec.elems.get(two).unwrap(), &MclFr::from(2));
// setting the same index should overwrite
vec.set(&MclFr::from(2), &MclFr::from(3));
assert_eq!(vec.elems.len(), 1);
assert_eq!(vec.elems.get(two).unwrap(), &MclFr::from(3));
// setting 0 should do nothing
vec.set(&MclFr::from(2), &MclFr::zero());
assert_eq!(vec.elems.len(), 1);
assert_eq!(vec.elems.get(two).unwrap(), &MclFr::from(3));
}
#[test]
fn test_assign() {
MclInitializer::init();
let mut vec = MclSparseVec::new(&MclFr::from(3));
assert_eq!(vec.elems.len(), 0);
vec.set(&MclFr::from(2), &MclFr::from(2));
assert_eq!(vec.elems.len(), 1);
assert_eq!(vec.elems.get(&MclFr::from(2)).unwrap(), &MclFr::from(2));
let indices = vec.indices();
assert_eq!(indices.len(), 1);
assert_eq!(indices[0], MclFr::from(2));
// setting the same index should overwrite
vec.set(&MclFr::from(2), &MclFr::from(3));
assert_eq!(vec.elems.len(), 1);
assert_eq!(vec.elems.get(&MclFr::from(2)).unwrap(), &MclFr::from(3));
}
#[test]
fn test_good_get() {
MclInitializer::init();
let zero = &MclFr::zero();
let one = &MclFr::from(1);
let two = &MclFr::from(2);
let three = &MclFr::from(3);
let mut vec = MclSparseVec::new(&MclFr::from(3));
vec.set(zero, one);
vec.set(one, two);
vec.set(two, three);
assert_eq!(vec.get(zero), one);
assert_eq!(vec.get(one), two);
assert_eq!(vec.get(two), three);
}
#[test]
#[should_panic]
fn test_get_index_out_of_range() {
MclInitializer::init();
std::panic::set_hook(Box::new(|_| {})); // suppress stack trace
let vec = MclSparseVec::new(&MclFr::from(1));
vec.get(&MclFr::from(2));
}
#[test]
fn test_get_index_without_value() {
MclInitializer::init();
std::panic::set_hook(Box::new(|_| {})); // suppress stack trace
let zero = &MclFr::zero();
let vec = MclSparseVec::new(&MclFr::from(3));
assert_eq!(vec.get(&MclFr::from(0)), zero);
assert_eq!(vec.get(&MclFr::from(1)), zero);
assert_eq!(vec.get(&MclFr::from(2)), zero);
}
#[test]
fn test_indices() {
MclInitializer::init();
let mut vec = MclSparseVec::new(&MclFr::from(3));
vec.set(&MclFr::from(1), &MclFr::from(2));
vec.set(&MclFr::from(2), &MclFr::from(4));
let indices = vec.indices();
assert_eq!(indices.len(), 2);
assert!(indices.contains(&MclFr::from(1)));
assert!(indices.contains(&MclFr::from(2)));
}
#[test]
fn test_mutiply_no_matching_elems() {
MclInitializer::init();
let mut vec_a = MclSparseVec::new(&MclFr::from(3));
let mut vec_b = MclSparseVec::new(&MclFr::from(3));
vec_a.set(&MclFr::from(1), &MclFr::from(2));
vec_b.set(&MclFr::from(2), &MclFr::from(3));
let vec_c = &vec_a * &vec_b;
assert_eq!(vec_c.elems.len(), 0);
}
#[test]
fn test_mutiply_elems_fully_matching_1_elem() {
MclInitializer::init();
let mut vec_a = MclSparseVec::new(&MclFr::from(3));
let mut vec_b = MclSparseVec::new(&MclFr::from(3));
vec_a.set(&MclFr::from(1), &MclFr::from(2));
vec_b.set(&MclFr::from(1), &MclFr::from(3));
let vec_c = &vec_a * &vec_b;
assert_eq!(vec_c.elems.len(), 1);
assert_eq!(vec_c.get(&MclFr::from(1)), &MclFr::from(6));
}
#[test]
fn test_mutiply_elems_fully_matching_2_elems() {
MclInitializer::init();
let mut vec_a = MclSparseVec::new(&MclFr::from(3));
let mut vec_b = MclSparseVec::new(&MclFr::from(3));
vec_a.set(&MclFr::from(1), &MclFr::from(2));
vec_a.set(&MclFr::from(2), &MclFr::from(3));
vec_b.set(&MclFr::from(1), &MclFr::from(4));
vec_b.set(&MclFr::from(2), &MclFr::from(5));
let vec_c = &vec_a * &vec_b;
assert_eq!(vec_c.elems.len(), 2);
assert_eq!(vec_c.get(&MclFr::from(1)), &MclFr::from(8));
assert_eq!(vec_c.get(&MclFr::from(2)), &MclFr::from(15));
}
#[test]
fn test_mutiply_elems_partially_matching() {
MclInitializer::init();
let mut vec_a = MclSparseVec::new(&MclFr::from(3));
let mut vec_b = MclSparseVec::new(&MclFr::from(3));
vec_a.set(&MclFr::from(1), &MclFr::from(2));
vec_a.set(&MclFr::from(2), &MclFr::from(5));
vec_b.set(&MclFr::from(1), &MclFr::from(3));
let vec_c = &vec_a * &vec_b;
assert_eq!(vec_c.elems.len(), 1);
assert_eq!(vec_c.get(&MclFr::from(1)), &MclFr::from(6));
}
#[test]
fn test_sum() {
MclInitializer::init();
let mut vec = MclSparseVec::new(&MclFr::from(3));
vec.set(&MclFr::from(1), &MclFr::from(2));
vec.set(&MclFr::from(2), &MclFr::from(4));
let sum = vec.sum();
assert_eq!(sum, MclFr::from(6));
}
#[test]
fn test_eq_different_sizes() {
MclInitializer::init();
let vec_a = MclSparseVec::new(&MclFr::from(3));
let vec_b = MclSparseVec::new(&MclFr::from(4));
assert_ne!(vec_a, vec_b);
assert_ne!(vec_b, vec_a);
}
#[test]
fn test_eq_empty() {
MclInitializer::init();
let vec_a = MclSparseVec::new(&MclFr::from(3));
let vec_b = MclSparseVec::new(&MclFr::from(3));
assert_eq!(vec_a, vec_b);
assert_eq!(vec_b, vec_a);
}
#[test]
fn test_eq_non_empty() {
MclInitializer::init();
let mut vec_a = MclSparseVec::new(&MclFr::from(3));
let mut vec_b = MclSparseVec::new(&MclFr::from(3));
vec_a.set(&MclFr::from(1), &MclFr::from(92));
vec_b.set(&MclFr::from(1), &MclFr::from(92));
assert_eq!(vec_a, vec_b);
assert_eq!(vec_b, vec_a);
}
#[test]
fn test_not_eq_non_empty() {
MclInitializer::init();
let mut vec_a = MclSparseVec::new(&MclFr::from(3));
let mut vec_b = MclSparseVec::new(&MclFr::from(3));
vec_a.set(&MclFr::from(1), &MclFr::from(92));
vec_b.set(&MclFr::from(1), &MclFr::from(13));
assert_ne!(vec_a, vec_b);
assert_ne!(vec_b, vec_a);
}
#[test]
fn test_iterator() {
MclInitializer::init();
let mut sv = MclSparseVec::new(&MclFr::from(3));
sv.set(&MclFr::from(0), &MclFr::from(1));
sv.set(&MclFr::from(1), &MclFr::from(2));
sv.set(&MclFr::from(2), &MclFr::from(3));
let it = &mut sv.iter();
assert!(&it.next().unwrap() == &MclFr::from(1));
assert!(&it.next().unwrap() == &MclFr::from(2));
assert!(&it.next().unwrap() == &MclFr::from(3));
assert!(it.next() == None);
}
#[test]
fn test_concat() {
MclInitializer::init();
let mut sv1 = MclSparseVec::new(&MclFr::from(2));
sv1.set(&MclFr::from(0), &MclFr::from(1));
sv1.set(&MclFr::from(1), &MclFr::from(2));
let mut sv2 = MclSparseVec::new(&MclFr::from(2));
sv2.set(&MclFr::from(0), &MclFr::from(3));
sv2.set(&MclFr::from(1), &MclFr::from(4));
let sv3 = sv1.concat(&sv2);
assert!(sv3.get(&MclFr::from(0)) == &MclFr::from(1));
assert!(sv3.get(&MclFr::from(1)) == &MclFr::from(2));
assert!(sv3.get(&MclFr::from(2)) == &MclFr::from(3));
assert!(sv3.get(&MclFr::from(3)) == &MclFr::from(4));
}
}

View File

@@ -0,0 +1,11 @@
pub mod mcl_g1;
pub mod mcl_g2;
pub mod mcl_gt;
pub mod mcl_fr;
pub mod mcl_initializer;
pub mod mcl_sparse_matrix;
pub mod mcl_sparse_vec;
pub mod pairing;
pub mod polynomial;
pub mod qap;

View File

@@ -0,0 +1,130 @@
use mcl_rust::*;
use crate::building_block::mcl::{
mcl_g1::MclG1,
mcl_g2::MclG2,
mcl_gt::MclGT,
};
#[derive(Debug, Clone)]
pub struct Pairing;
impl Pairing {
pub fn e(&self, p1: &MclG1, p2: &MclG2) -> MclGT {
let mut v = GT::zero();
pairing(&mut v, &p1.v, &p2.v);
MclGT::from(&v)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::building_block::mcl::{
mcl_fr::MclFr,
mcl_initializer::MclInitializer,
};
fn test(
pairing: &Pairing,
pair: &dyn Fn(&Pairing, &MclG1, &MclG2) -> MclGT,
p1: &MclG1,
p2: &MclG2,
) -> bool {
let ten = MclFr::from(10);
let ten_p1s = p1 * &ten;
// test e(p1 + ten_p1s, p2) = e(p1, p2) e(ten_p1s, p2)
let lhs = pair(pairing, &(p1 + &ten_p1s), p2);
let rhs1 = pair(pairing, p1, p2);
let rhs2 = pair(pairing, &ten_p1s, p2);
let rhs = rhs1 * rhs2;
lhs == rhs
}
fn test_with_generators(
pair: &dyn Fn(&Pairing, &MclG1, &MclG2) -> MclGT,
) {
let pairing = &Pairing;
let p1 = MclG1::g();
let p2 = MclG2::g();
let res = test(pairing, pair, &p1, &p2);
assert!(res);
}
fn test_with_random_points(
pair: &dyn Fn(&Pairing, &MclG1, &MclG2) -> MclGT,
) {
let mut errors = 0;
let num_tests = 1;
for _ in 0..num_tests {
let pairing = &Pairing;
let p1 = MclG1::get_random_point();
let p2 = MclG2::get_random_point();
let res = test(pairing, pair, &p1, &p2);
if res == false {
errors += 1;
}
}
assert!(errors == 0);
}
fn test_plus_to_mul(pair: &dyn Fn(&Pairing, &MclG1, &MclG2) -> MclGT,
) {
let pairing = &Pairing;
let one = &MclG2::g();
let p = &(MclG1::g() + MclG1::g());
let lhs = {
let p_plus_p = p + p;
pair(pairing, &p_plus_p, one)
};
let rhs = {
let a = &pair(pairing, &p, one);
a * a
};
assert!(lhs == rhs);
}
#[test]
fn test_weil_pairing_with_generators() {
MclInitializer::init();
test_with_generators(&Pairing::e);
}
#[test]
fn test_weil_pairing_with_random_points() {
MclInitializer::init();
test_with_random_points(&Pairing::e);
}
#[test]
fn test_tate_pairing_with_test_plus_to_mul() {
MclInitializer::init();
test_plus_to_mul(&Pairing::e);
}
#[test]
fn test_signature_verification() {
MclInitializer::init();
let pairing = &Pairing;
let g1 = &MclG1::g();
let sk = &MclFr::from(2);
let pk = &(g1 * sk);
let m = &b"hamburg steak".to_vec();
let hash_m = &MclG2::hash_and_map(m);
// e(pk, H(m)) = e(g1*sk, H(m)) = e(g1, sk*H(m))
let lhs = pairing.e(pk, hash_m);
let rhs = pairing.e(g1, &(hash_m * sk));
assert!(lhs == rhs);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
pub type SignalId = u128;

View File

@@ -0,0 +1,28 @@
use crate::building_block::mcl::mcl_sparse_vec::MclSparseVec;
#[derive(Clone)]
pub struct Constraint {
pub a: MclSparseVec,
pub b: MclSparseVec,
pub c: MclSparseVec,
}
impl Constraint {
pub fn new(
a: &MclSparseVec,
b: &MclSparseVec,
c: &MclSparseVec,
) -> Self {
let a = a.clone();
let b = b.clone();
let c = c.clone();
Constraint { a, b, c }
}
}
impl std::fmt::Debug for Constraint {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{:?} . w * {:?} . w - {:?} . w = 0", self.a, self.b, self.c)
}
}

View File

@@ -0,0 +1,686 @@
use nom::{
IResult,
branch::alt,
bytes::complete::tag,
character::complete::{ alpha1, char, multispace0, one_of },
combinator::{ opt, recognize },
multi::{ many0, many1 },
sequence::{ tuple, delimited, terminated },
};
use crate::building_block::mcl::mcl_fr::MclFr;
use crate::zk::w_trusted_setup::qap::config::SignalId;
use std::cell::Cell;
#[derive(Debug, PartialEq, Clone)]
pub enum MathExpr {
Equation(Box<MathExpr>, Box<MathExpr>),
Num(MclFr),
Var(String),
Mul(SignalId, Box<MathExpr>, Box<MathExpr>),
Add(SignalId, Box<MathExpr>, Box<MathExpr>),
Div(SignalId, Box<MathExpr>, Box<MathExpr>),
Sub(SignalId, Box<MathExpr>, Box<MathExpr>),
}
#[derive(Debug)]
pub struct Equation {
pub lhs: MathExpr,
pub rhs: MclFr,
}
pub struct EquationParser();
macro_rules! set_next_id {
($signal_id: expr) => {
$signal_id.set($signal_id.get() + 1);
};
}
impl EquationParser {
fn num_str_to_field_elem(s: &str) -> MclFr {
MclFr::from(s)
}
fn variable<'a>() -> impl Fn(&str) -> IResult<&str, MathExpr> + 'a {
|input| {
let (input, s) =
delimited(
multispace0,
recognize(
terminated(alpha1, many0(one_of("0123456789")))
),
multispace0
)(input)?;
Ok((input, MathExpr::Var(s.to_string())))
}
}
fn decimal() -> impl Fn(&str) -> IResult<&str, MathExpr> {
|input| {
let (input, s) =
delimited(
multispace0,
recognize(
tuple((
opt(char('-')),
many1(
one_of("0123456789")
),
)),
),
multispace0
)(input)?;
let n = EquationParser::num_str_to_field_elem(s);
Ok((input, MathExpr::Num(n)))
}
}
// <term2> ::= <variable> | <number> | '(' <expr> ')'
fn term2<'a>(signal_id: &'a Cell<u128>) -> impl Fn(&str) -> IResult<&str, MathExpr> + 'a {
|input| {
let (input, node) = alt((
EquationParser::variable(),
EquationParser::decimal(),
delimited(
delimited(multispace0, char('('), multispace0),
EquationParser::expr(signal_id),
delimited(multispace0, char(')'), multispace0),
),
))(input)?;
Ok((input, node))
}
}
// <term1> ::= <term2> [ ('*'|'/') <term2> ]*
fn term1<'a>(signal_id: &'a Cell<u128>) -> impl Fn(&str) -> IResult<&str, MathExpr> + 'a {
|input| {
let rhs = tuple((alt((char('*'), char('/'))), EquationParser::term2(signal_id)));
let (input, (lhs, rhs)) = tuple((
EquationParser::term2(signal_id),
many0(rhs),
))(input)?;
if rhs.len() == 0 {
Ok((input, lhs))
} else {
// translate rhs vector to Mul<Mul<..,Mul>>>..
let rhs_head = &rhs[0];
let rhs = rhs.iter().skip(1).fold(rhs_head.1.clone(), |acc, x| {
match x {
('*', node) => {
set_next_id!(signal_id);
MathExpr::Mul(signal_id.get(), Box::new(acc), Box::new(node.clone()))
},
('/', node) => {
set_next_id!(signal_id);
MathExpr::Div(signal_id.get(), Box::new(acc), Box::new(node.clone()))
},
(op, _) => panic!("unexpected operator encountered in term1 {}", op),
}
});
set_next_id!(signal_id);
let node = if rhs_head.0 == '*' {
MathExpr::Mul(signal_id.get(), Box::new(lhs), Box::new(rhs))
} else {
MathExpr::Div(signal_id.get(), Box::new(lhs), Box::new(rhs))
};
Ok((input, node))
}
}
}
// <expr> ::= <term1> [ ('+'|'-') <term1> ]*
fn expr<'a>(signal_id: &'a Cell<u128>) -> impl Fn(&str) -> IResult<&str, MathExpr> + 'a {
|input| {
let rhs = tuple((alt((char('+'), char('-'))), EquationParser::term1(signal_id)));
let (input, (lhs, rhs)) = tuple((
EquationParser::term1(signal_id),
many0(rhs),
))(input)?;
if rhs.len() == 0 {
Ok((input, lhs))
} else {
// translate rhs vector to Add<Add<..,Add>>>..
let rhs_head = &rhs[0];
let rhs = rhs.iter().skip(1).fold(rhs_head.1.clone(), |acc, x| {
match x {
('+', node) => {
set_next_id!(signal_id);
MathExpr::Add(signal_id.get(), Box::new(acc), Box::new(node.clone()))
},
('-', node) => {
set_next_id!(signal_id);
MathExpr::Sub(signal_id.get(), Box::new(acc), Box::new(node.clone()))
},
(op, _) => panic!("unexpected operator encountered in expr: {}", op),
}
});
set_next_id!(signal_id);
let node = if rhs_head.0 == '+' {
MathExpr::Add(signal_id.get(), Box::new(lhs), Box::new(rhs))
} else {
MathExpr::Sub(signal_id.get(), Box::new(lhs), Box::new(rhs))
};
Ok((input, node))
}
}
}
// <equation> ::= <expr> '=' <number>
fn equation<'a>(signal_id: &'a Cell<u128>) -> impl Fn(&str) -> IResult<&str, MathExpr> + 'a {
|input| {
let (input, out) =
tuple((
multispace0,
EquationParser::expr(signal_id),
multispace0,
tag("=="),
multispace0,
EquationParser::decimal(),
multispace0,
))(input)?;
let lhs = out.1;
let rhs = out.5;
Ok((input, MathExpr::Equation(Box::new(lhs), Box::new(rhs))))
}
}
// <term1> ::= <term2> [ ('*'|'/') <term2> ]*
// <term2> ::= <variable> | <number> | '(' <expr> ')'
// <expr> ::= <term1> [ ('+'|'-') <term1> ]*
// <equation> ::= <expr> '==' <number>
pub fn parse<'a>(input: &'a str) -> Result<Equation, String> {
let signal_id = Cell::new(0);
let expr = EquationParser::equation(&signal_id);
match expr(input) {
Ok((_, expr)) => {
match expr {
MathExpr::Equation(lhs, rhs) => {
if let MathExpr::Num(n) = *rhs {
Ok(Equation { lhs: *lhs, rhs: n })
} else {
Err(format!("Equation has unexpected RHS: {:?}", rhs))
}
},
_ => Err(format!("Unexpected parse result: {:?}", expr))
}
},
Err(x) => Err(x.to_string()),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::building_block::mcl::mcl_initializer::MclInitializer;
#[test]
fn test_decimal() {
MclInitializer::init();
match EquationParser::parse("123 == 123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Num(MclFr::from(123)));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_decimal_with_spaces() {
MclInitializer::init();
match EquationParser::parse(" 123 == 123 ") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Num(MclFr::from(123)));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_neg_decimal_below_order() {
MclInitializer::init();
match EquationParser::parse("-123 == -123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Num(MclFr::from(-123)));
assert_eq!(eq.rhs, MclFr::from(-123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_neg_decimal_above_order() {
MclInitializer::init();
match EquationParser::parse("-123 == -123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Num(MclFr::from(-123)));
assert_eq!(eq.rhs, MclFr::from(-123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_1_char_variable() {
MclInitializer::init();
for s in vec!["x", "x1", "x0", "xy", "xy1"] {
match EquationParser::parse(&format!("{} == 123", s)) {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Var(s.to_string()));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
}
#[test]
fn test_1_char_variable_with_spaces() {
MclInitializer::init();
for s in vec!["x", "x1", "x0", "xy", "xy1"] {
match EquationParser::parse(&format!(" {} == 123 ", s)) {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Var(s.to_string()));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
}
#[test]
fn test_simple_add_expr() {
MclInitializer::init();
match EquationParser::parse("123+456==123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Add(1,
Box::new(MathExpr::Num(MclFr::from(123))),
Box::new(MathExpr::Num(MclFr::from(456))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_simple_add_expr_with_1_var() {
MclInitializer::init();
for s in vec!["x", "x1", "x0", "xy", "xy1"] {
match EquationParser::parse(&format!("{}+456==123", s)) {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Add(1,
Box::new(MathExpr::Var(s.to_string())),
Box::new(MathExpr::Num(MclFr::from(456))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
}
#[test]
fn test_simple_add_expr_with_2_vars() {
MclInitializer::init();
for (a,b) in vec![("x", "y"), ("x1", "y1"), ("xxx1123", "yyy123443")] {
match EquationParser::parse(&format!("{}+{}==123", a, b)) {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Add(1,
Box::new(MathExpr::Var(a.to_string())),
Box::new(MathExpr::Var(b.to_string())),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
}
#[test]
fn test_simple_add_expr_incl_neg() {
MclInitializer::init();
match EquationParser::parse("123+-456==123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Add(1,
Box::new(MathExpr::Num(MclFr::from(123))),
Box::new(MathExpr::Num(MclFr::from(-456))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_simple_sub_expr() {
MclInitializer::init();
match EquationParser::parse("123-456==123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Sub(1,
Box::new(MathExpr::Num(MclFr::from(123))),
Box::new(MathExpr::Num(MclFr::from(456))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_simple_sub_expr_1_var() {
MclInitializer::init();
for s in vec!["x", "x1", "x0", "xy", "xy1"] {
match EquationParser::parse(&format!("123-{}==123", s)) {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Sub(1,
Box::new(MathExpr::Num(MclFr::from(123))),
Box::new(MathExpr::Var(s.to_string())),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
}
#[test]
fn test_simple_sub_expr_incl_neg1() {
MclInitializer::init();
match EquationParser::parse("-123-456==123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Sub(1,
Box::new(MathExpr::Num(MclFr::from(-123))),
Box::new(MathExpr::Num(MclFr::from(456))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_simple_sub_expr_incl_neg1_1_var() {
MclInitializer::init();
for s in vec!["x", "x1", "x0", "xy", "xy1"] {
match EquationParser::parse(&format!("-123-{}==123", s)) {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Sub(1,
Box::new(MathExpr::Num(MclFr::from(-123))),
Box::new(MathExpr::Var(s.to_string())),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
}
#[test]
fn test_simple_sub_expr_incl_neg2() {
MclInitializer::init();
match EquationParser::parse("123--456==123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Sub(1,
Box::new(MathExpr::Num(MclFr::from(123))),
Box::new(MathExpr::Num(MclFr::from(-456))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_simple_sub_expr_incl_neg2_with_spaces() {
MclInitializer::init();
match EquationParser::parse("123 - -456 == 123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Sub(1,
Box::new(MathExpr::Num(MclFr::from(123))),
Box::new(MathExpr::Num(MclFr::from(-456))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_simple_sub_expr_incl_neg2_with_spaces_1_var() {
MclInitializer::init();
match EquationParser::parse("x - -456 == 123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Sub(1,
Box::new(MathExpr::Var("x".to_string())),
Box::new(MathExpr::Num(MclFr::from(-456))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_simple_mul_expr() {
MclInitializer::init();
match EquationParser::parse("123*456==123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Mul(1,
Box::new(MathExpr::Num(MclFr::from(123))),
Box::new(MathExpr::Num(MclFr::from(456))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_simple_mul_expr_incl_neg1() {
MclInitializer::init();
match EquationParser::parse("123*-456==123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Mul(1,
Box::new(MathExpr::Num(MclFr::from(123))),
Box::new(MathExpr::Num(MclFr::from(-456))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_simple_mul_expr_incl_neg2() {
MclInitializer::init();
match EquationParser::parse("-123*456==123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Mul(1,
Box::new(MathExpr::Num(MclFr::from(-123))),
Box::new(MathExpr::Num(MclFr::from(456))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_simple_mul_expr_incl_neg() {
MclInitializer::init();
match EquationParser::parse("123*-456==123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Mul(1,
Box::new(MathExpr::Num(MclFr::from(123))),
Box::new(MathExpr::Num(MclFr::from(-456))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_simple_div_expr() {
MclInitializer::init();
match EquationParser::parse("123/456==123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Div(1,
Box::new(MathExpr::Num(MclFr::from(123))),
Box::new(MathExpr::Num(MclFr::from(456))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_add_and_mul_expr() {
MclInitializer::init();
match EquationParser::parse("123+456*789==123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Add(2,
Box::new(MathExpr::Num(MclFr::from(123))),
Box::new(MathExpr::Mul(1,
Box::new(MathExpr::Num(MclFr::from(456))),
Box::new(MathExpr::Num(MclFr::from(789))),
)),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_add_mul_div_expr() {
MclInitializer::init();
match EquationParser::parse("111/222+333*444==123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Add(3,
Box::new(MathExpr::Div(1,
Box::new(MathExpr::Num(MclFr::from(111))),
Box::new(MathExpr::Num(MclFr::from(222))),
)),
Box::new(MathExpr::Mul(2,
Box::new(MathExpr::Num(MclFr::from(333))),
Box::new(MathExpr::Num(MclFr::from(444))),
)),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_paren_add_and_mul_expr() {
MclInitializer::init();
match EquationParser::parse("(123+456)*789==123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Mul(2,
Box::new(MathExpr::Add(1,
Box::new(MathExpr::Num(MclFr::from(123))),
Box::new(MathExpr::Num(MclFr::from(456))),
)),
Box::new(MathExpr::Num(MclFr::from(789))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_paren_add_and_mul_expr_with_spaces() {
MclInitializer::init();
match EquationParser::parse(" (123 + 456) * 789 == 123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Mul(2,
Box::new(MathExpr::Add(1,
Box::new(MathExpr::Num(MclFr::from(123))),
Box::new(MathExpr::Num(MclFr::from(456))),
)),
Box::new(MathExpr::Num(MclFr::from(789))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_paren_add_mul_sub_expr() {
MclInitializer::init();
match EquationParser::parse("(111+222)*(333-444)==123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Mul(3,
Box::new(MathExpr::Add(1,
Box::new(MathExpr::Num(MclFr::from(111))),
Box::new(MathExpr::Num(MclFr::from(222))),
)),
Box::new(MathExpr::Sub(2,
Box::new(MathExpr::Num(MclFr::from(333))),
Box::new(MathExpr::Num(MclFr::from(444))),
)),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_multiple_paren() {
MclInitializer::init();
match EquationParser::parse("((111+222))==123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Add(1,
Box::new(MathExpr::Num(MclFr::from(111))),
Box::new(MathExpr::Num(MclFr::from(222))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn test_multiple_paren_with_spaces() {
MclInitializer::init();
match EquationParser::parse(" ( (111+222) ) == 123") {
Ok(eq) => {
assert_eq!(eq.lhs, MathExpr::Add(1,
Box::new(MathExpr::Num(MclFr::from(111))),
Box::new(MathExpr::Num(MclFr::from(222))),
));
assert_eq!(eq.rhs, MclFr::from(123));
},
Err(_) => panic!(),
}
}
#[test]
fn blog_post_1_example_1() {
MclInitializer::init();
let expr = "(x * x * x) + x + 5 == 35";
match EquationParser::parse(expr) {
Ok(eq) => {
println!("{} -> {:?}", expr, eq);
},
Err(_) => panic!(),
}
}
}

View File

@@ -0,0 +1,251 @@
use crate::building_block::mcl::qap::{
equation_parser::{
Equation,
MathExpr,
},
term::Term,
};
pub struct Gate {
pub a: Term,
pub b: Term,
pub c: Term,
}
impl std::fmt::Debug for Gate {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{:?} * {:?} = {:?}", self.a, self.b, self.c)
}
}
impl Gate {
// traverse the Equation tree generating statement at each Add/Mul node
fn traverse_lhs(
expr: &MathExpr, gates: &mut Vec<Gate>
) -> Term {
match expr {
MathExpr::Num(n) => Term::Num(n.clone()),
MathExpr::Var(s) => Term::Var(s.clone()),
MathExpr::Add(signal_id, left, right) => {
let a = Gate::traverse_lhs(left, gates);
let b = Gate::traverse_lhs(right, gates);
let c = Term::TmpVar(*signal_id);
// a + b = c
// -> (a + b) * 1 = c
let sum = Term::Sum(Box::new(a), Box::new(b));
gates.push(Gate { a: sum, b: Term::One, c: c.clone() });
c
},
MathExpr::Mul(signal_id, left, right) => {
let a = Gate::traverse_lhs(left, gates);
let b = Gate::traverse_lhs(right, gates);
let c = Term::TmpVar(*signal_id);
gates.push(Gate { a, b, c: c.clone() });
c
},
MathExpr::Sub(signal_id, left, right) => {
let a = Gate::traverse_lhs(left, gates);
let b = Gate::traverse_lhs(right, gates);
let c = Term::TmpVar(*signal_id);
// a - b = c
// -> b + c = a
// -> (b + c) * 1 = a
let sum = Term::Sum(Box::new(b), Box::new(c.clone()));
gates.push(Gate { a: sum, b: Term::One, c: a.clone() });
c
},
MathExpr::Div(signal_id, left, right) => {
let a = Gate::traverse_lhs(left, gates);
let b = Gate::traverse_lhs(right, gates);
let c = Term::TmpVar(*signal_id);
// a / b = c
// -> b * c = a
gates.push(Gate { a: b, b: c.clone(), c: a.clone() });
// send c to next as the original division does
c
},
MathExpr::Equation(_lhs, _rhs) => {
panic!("should not be visited");
}
}
}
pub fn build(eq: &Equation) -> Vec<Gate> {
let mut gates: Vec<Gate> = vec![];
let root = Gate::traverse_lhs(&eq.lhs, &mut gates);
let out_gate = Gate { a: root, b: Term::One, c: Term::Out };
gates.push(out_gate);
gates
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::building_block::mcl::{
mcl_fr::MclFr,
mcl_initializer::MclInitializer,
qap::equation_parser::EquationParser,
};
#[test]
fn test_build_add() {
MclInitializer::init();
let input = "x + 4 == 9";
let eq = EquationParser::parse(input).unwrap();
let gates = &Gate::build(&eq);
println!("{:?}", gates);
assert_eq!(gates.len(), 2);
// t1 = (x + 4) * 1
assert_eq!(gates[0].a, Term::Sum(Box::new(Term::Var("x".to_string())), Box::new(Term::Num(MclFr::from(4)))));
assert_eq!(gates[0].b, Term::One);
assert_eq!(gates[0].c, Term::TmpVar(1));
// out = t1 * 1
assert_eq!(gates[1].a, Term::TmpVar(1));
assert_eq!(gates[1].b, Term::One);
assert_eq!(gates[1].c, Term::Out);
}
#[test]
fn test_build_sub() {
MclInitializer::init();
let input = "x - 4 == 9";
let eq = EquationParser::parse(input).unwrap();
let gates = &Gate::build(&eq);
assert_eq!(gates.len(), 2);
// t1 = (x + 4) * 1
assert_eq!(gates[0].a, Term::Sum(Box::new(Term::Num(MclFr::from(4))), Box::new(Term::TmpVar(1))));
assert_eq!(gates[0].b, Term::One);
assert_eq!(gates[0].c, Term::Var("x".to_string()));
// out = t1 * 1
assert_eq!(gates[1].a, Term::TmpVar(1));
assert_eq!(gates[1].b, Term::One);
assert_eq!(gates[1].c, Term::Out);
}
#[test]
fn test_build_mul() {
MclInitializer::init();
let input = "x * 4 == 9";
let eq = EquationParser::parse(input).unwrap();
let gates = &Gate::build(&eq);
assert_eq!(gates.len(), 2);
// x = (4 + t1) * 1
assert_eq!(gates[0].a, Term::Var("x".to_string()));
assert_eq!(gates[0].b, Term::Num(MclFr::from(4)));
assert_eq!(gates[0].c, Term::TmpVar(1));
// out = t1 * 1
assert_eq!(gates[1].a, Term::TmpVar(1));
assert_eq!(gates[1].c, Term::Out);
}
#[test]
fn test_build_div() {
MclInitializer::init();
let input = "x / 4 == 2";
let eq = EquationParser::parse(input).unwrap();
let gates = &Gate::build(&eq);
assert_eq!(gates.len(), 2);
// x = 4 * t1
assert_eq!(gates[0].a, Term::Num(MclFr::from(4)));
assert_eq!(gates[0].b, Term::TmpVar(1));
assert_eq!(gates[0].c, Term::Var("x".to_string()));
// out = t1 * 1
assert_eq!(gates[1].a, Term::TmpVar(1));
assert_eq!(gates[1].c, Term::Out);
}
#[test]
fn test_build_combined1() {
MclInitializer::init();
let input = "(3 * x + 4) / 2 == 11";
let eq = EquationParser::parse(input).unwrap();
let gates = &Gate::build(&eq);
assert_eq!(gates.len(), 4);
// t1 = 3 * x
assert_eq!(gates[0].a, Term::Num(MclFr::from(3)));
assert_eq!(gates[0].b, Term::Var("x".to_string()));
assert_eq!(gates[0].c, Term::TmpVar(1));
// t2 = (t1 + 4) * 1
assert_eq!(gates[1].a, Term::Sum(
Box::new(Term::TmpVar(1)),
Box::new(Term::Num(MclFr::from(4)))
));
assert_eq!(gates[1].b, Term::One);
assert_eq!(gates[1].c, Term::TmpVar(2));
// t2 = 2 * t3
assert_eq!(gates[2].a, Term::Num(MclFr::from(2)));
assert_eq!(gates[2].b, Term::TmpVar(3));
assert_eq!(gates[2].c, Term::TmpVar(2));
// out = t3 * 1
assert_eq!(gates[3].a, Term::TmpVar(3));
assert_eq!(gates[3].b, Term::One);
assert_eq!(gates[3].c, Term::Out);
}
#[test]
fn test_build_combined2() {
MclInitializer::init();
let input = "(x * x * x) + x + 5 == 35";
println!("Equation: {}", input);
let eq = EquationParser::parse(input).unwrap();
let gates = &Gate::build(&eq);
println!("Gates: {:?}", gates);
assert_eq!(gates.len(), 5);
// t1 = x * x
assert_eq!(gates[0].a, Term::Var("x".to_string()));
assert_eq!(gates[0].b, Term::Var("x".to_string()));
assert_eq!(gates[0].c, Term::TmpVar(1));
// t2 = x * t1
assert_eq!(gates[1].a, Term::Var("x".to_string()));
assert_eq!(gates[1].b, Term::TmpVar(1));
assert_eq!(gates[1].c, Term::TmpVar(2));
// t3 = (x + 5) * 1
assert_eq!(gates[2].a, Term::Sum(
Box::new(Term::Var("x".to_string())),
Box::new(Term::Num(MclFr::from(5)))
));
assert_eq!(gates[2].b, Term::One);
assert_eq!(gates[2].c, Term::TmpVar(3));
// t4 = (t2 + t3) * 1
assert_eq!(gates[3].a, Term::Sum(
Box::new(Term::TmpVar(2)),
Box::new(Term::TmpVar(3))
));
assert_eq!(gates[3].b, Term::One);
assert_eq!(gates[3].c, Term::TmpVar(4));
// out = t4 * 1
assert_eq!(gates[4].a, Term::TmpVar(4));
assert_eq!(gates[4].b, Term::One);
assert_eq!(gates[4].c, Term::Out);
}
#[test]
fn blog_post_1_example_1() {
MclInitializer::init();
let expr = "(x * x * x) + x + 5 == 35";
let eq = EquationParser::parse(expr).unwrap();
let gates = &Gate::build(&eq);
println!("{:?}", gates);
}
}

View File

@@ -0,0 +1,131 @@
use crate::zk::w_trusted_setup::qap::gates::bool_circuit::{BoolCircuit, Processor};
pub struct HalfAdder();
#[derive(Debug)]
pub struct AdderResult {
pub sum: bool,
pub carry: bool,
}
impl HalfAdder {
// (augend, addend) -> (sum, carry)
pub fn add(augend: bool, addend: bool) -> AdderResult {
let sum = BoolCircuit::Xor(
Box::new(BoolCircuit::Leaf(augend)),
Box::new(BoolCircuit::Leaf(addend)),
);
let carry = BoolCircuit::And(
Box::new(BoolCircuit::Leaf(augend)),
Box::new(BoolCircuit::Leaf(addend)),
);
let sum = Processor::eval(&sum);
let carry = Processor::eval(&carry);
AdderResult { sum, carry }
}
}
pub struct FullAdder();
impl FullAdder {
pub fn add(augend: bool, addend: bool, carry: bool) -> AdderResult {
let res1 = HalfAdder::add(augend, addend);
let res2 = HalfAdder::add(res1.sum, carry);
let carry = BoolCircuit::Or(
Box::new(BoolCircuit::Leaf(res1.carry)),
Box::new(BoolCircuit::Leaf(res2.carry)),
);
let carry = Processor::eval(&carry);
AdderResult { sum: res2.sum, carry }
}
}
#[cfg(test)]
mod half_adder_tests {
use super::*;
use crate::building_block::mcl::mcl_initializer::MclInitializer;
#[test]
fn add_0_0() {
MclInitializer::init();
let res = HalfAdder::add(false, false);
assert_eq!(res.sum, false);
assert_eq!(res.carry, false);
}
#[test]
fn add_1_0_or_0_1() {
MclInitializer::init();
let res = HalfAdder::add(true, false);
assert_eq!(res.sum, true);
assert_eq!(res.carry, false);
let res = HalfAdder::add(false, true);
assert_eq!(res.sum, true);
assert_eq!(res.carry, false);
}
#[test]
fn add_1_1() {
MclInitializer::init();
let res = HalfAdder::add(true, true);
assert_eq!(res.sum, false);
assert_eq!(res.carry, true);
}
}
#[cfg(test)]
mod full_adder_tests {
use super::*;
use crate::building_block::mcl::mcl_initializer::MclInitializer;
#[test]
fn single_inst_add_0_0_0() {
MclInitializer::init();
let res = FullAdder::add(false, false, false);
assert_eq!(res.sum, false);
assert_eq!(res.carry, false);
}
#[test]
fn single_inst_add_1_0_0_or_0_1_0_or_0_0_1() {
MclInitializer::init();
let res = FullAdder::add(true, false, false);
assert_eq!(res.sum, true);
assert_eq!(res.carry, false);
let res = FullAdder::add(false, true, false);
assert_eq!(res.sum, true);
assert_eq!(res.carry, false);
let res = FullAdder::add(false, false, true);
assert_eq!(res.sum, true);
assert_eq!(res.carry, false);
}
#[test]
fn single_inst_add_1_1_0_or_1_0_1_or_0_1_1() {
MclInitializer::init();
let res = FullAdder::add(true, true, false);
assert_eq!(res.sum, false);
assert_eq!(res.carry, true);
let res = FullAdder::add(true, false, true);
assert_eq!(res.sum, false);
assert_eq!(res.carry, true);
let res = FullAdder::add(false, true, true);
assert_eq!(res.sum, false);
assert_eq!(res.carry, true);
}
#[test]
fn single_inst_add_1_1_1() {
MclInitializer::init();
let res = FullAdder::add(true, true, true);
assert_eq!(res.sum, true);
assert_eq!(res.carry, true);
}
}

View File

@@ -0,0 +1,12 @@
use crate::building_block::mcl::mcl_fr::MclFr;
#[derive(Debug, PartialEq, Clone)]
pub enum ArithCircuit {
Leaf(MclFr),
Mul(Box<ArithCircuit>, Box<ArithCircuit>),
Add(Box<ArithCircuit>, Box<ArithCircuit>),
Sub(Box<ArithCircuit>, Box<ArithCircuit>),
Div(Box<ArithCircuit>, Box<ArithCircuit>),
}
pub struct Processor();

View File

@@ -0,0 +1,76 @@
use crate::building_block::mcl::{
mcl_fr::MclFr,
qap::gates::arith_circuit::ArithCircuit,
};
#[derive(Debug, PartialEq, Clone)]
pub enum BoolCircuit {
Leaf(bool),
And(Box<BoolCircuit>, Box<BoolCircuit>),
Xor(Box<BoolCircuit>, Box<BoolCircuit>),
Or(Box<BoolCircuit>, Box<BoolCircuit>),
}
pub struct Processor();
impl Processor {
pub fn eval(root: &BoolCircuit) -> bool {
match root {
BoolCircuit::Leaf(x) => *x,
BoolCircuit::And(a, b) => Processor::eval(&a) && Processor::eval(&b),
BoolCircuit::Xor(a, b) => {
let a = Processor::eval(&a);
let b = Processor::eval(&b);
!(a && b) && (a || b)
}
BoolCircuit::Or(a, b) => Processor::eval(&a) || Processor::eval(&b),
}
}
pub fn to_arith_circuit(root: BoolCircuit) -> ArithCircuit {
match root {
BoolCircuit::Leaf(x) => ArithCircuit::Leaf(MclFr::from(x)),
BoolCircuit::And(a, b) => {
let a = Processor::eval(&a);
let b = Processor::eval(&b);
let a = ArithCircuit::Leaf(MclFr::from(a));
let b = ArithCircuit::Leaf(MclFr::from(b));
// AND(a, b) = ab
ArithCircuit::Mul(Box::new(a), Box::new(b))
},
BoolCircuit::Xor(a, b) => {
let a = Processor::eval(&a);
let b = Processor::eval(&b);
let a = ArithCircuit::Leaf(MclFr::from(a));
let b = ArithCircuit::Leaf(MclFr::from(b));
// XOR(a, b) = (a + b) - 2 ab
let t1 = ArithCircuit::Add(
Box::new(a.clone()),
Box::new(b.clone()),
);
let two = ArithCircuit::Leaf(MclFr::from(2));
let t2 = ArithCircuit::Mul(Box::new(a), Box::new(b));
let t2 = ArithCircuit::Mul(Box::new(two), Box::new(t2));
ArithCircuit::Add(
Box::new(t1),
Box::new(t2),
)
},
BoolCircuit::Or(a, b) => {
let a = Processor::eval(&a);
let b = Processor::eval(&b);
let a = ArithCircuit::Leaf(MclFr::from(a));
let b = ArithCircuit::Leaf(MclFr::from(b));
// Or(a, b) = a + b - a * b
let t1 = ArithCircuit::Add(Box::new(a.clone()), Box::new(b.clone()));
let t2 = ArithCircuit::Mul(Box::new(a.clone()), Box::new(b.clone()));
ArithCircuit::Sub(
Box::new(t1),
Box::new(t2),
)
},
}
}
}

View File

@@ -0,0 +1,4 @@
pub mod arith_circuit;
pub mod bool_circuit;
pub mod number;
pub mod adder;

View File

@@ -0,0 +1,119 @@
#[derive(Debug)]
pub struct Number {
pub bits: [bool; 64],
}
impl Number {
pub fn new(n: i64) -> Self {
let mut bits = [false; 64];
if n == 0 {
Number { bits }
} else {
let mut m = n;
if n < 0 {
// convert to a positive number w/ the same bit representation
m = i64::MAX + n + 1;
}
let mut x = m;
let mut i = 0;
while x > 0 {
if x & 1 == 1 {
bits[i] = true;
}
i += 1;
x >>= 1;
}
if n < 0 { // set sign bit if originally a negative value
bits[63] = true;
}
Number { bits }
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::building_block::mcl::mcl_initializer::MclInitializer;
#[test]
fn zero() {
MclInitializer::init();
let x = Number::new(0);
for i in 0..64 {
assert_eq!(x.bits[i], false);
}
}
#[test]
fn pos_1() {
MclInitializer::init();
let x = Number::new(1);
for i in 0..64 {
let exp = i == 0;
assert_eq!(x.bits[i], exp);
}
}
#[test]
fn pos_5() {
MclInitializer::init();
let x = Number::new(5);
assert_eq!(x.bits[0], true);
assert_eq!(x.bits[1], false);
assert_eq!(x.bits[2], true);
for i in 3..64 {
assert_eq!(x.bits[i], false);
}
}
#[test]
fn pos_7() {
MclInitializer::init();
let x = Number::new(7);
assert_eq!(x.bits[0], true);
assert_eq!(x.bits[1], true);
assert_eq!(x.bits[2], true);
for i in 3..64 {
assert_eq!(x.bits[i], false);
}
}
#[test]
fn neg_1() {
MclInitializer::init();
let x = Number::new(-1);
for i in 0..64 {
assert_eq!(x.bits[i], true);
}
}
#[test]
fn neg_5() {
MclInitializer::init();
let x = Number::new(-5);
assert_eq!(x.bits[0], true);
assert_eq!(x.bits[1], true);
assert_eq!(x.bits[2], false);
for i in 3..64 {
assert_eq!(x.bits[i], true);
}
}
#[test]
fn neg_7() {
MclInitializer::init();
let x = Number::new(-7);
assert_eq!(x.bits[0], true);
assert_eq!(x.bits[1], false);
assert_eq!(x.bits[2], false);
for i in 3..64 {
assert_eq!(x.bits[i], true);
}
}
}

View File

@@ -0,0 +1,9 @@
pub mod config;
pub mod constraint;
pub mod equation_parser;
pub mod gate;
pub mod gates;
pub mod qap;
pub mod r1cs;
pub mod r1cs_tmpl;
pub mod term;

View File

@@ -0,0 +1,326 @@
use std::ops::Mul;
use crate::building_block::mcl::{
mcl_fr::MclFr,
mcl_sparse_vec::MclSparseVec,
polynomial::{
DivResult,
Polynomial,
},
qap::r1cs::R1CS,
};
use num_traits::Zero;
#[derive(Clone)]
pub struct QAP {
pub vi: Vec<Polynomial>,
pub wi: Vec<Polynomial>,
pub yi: Vec<Polynomial>,
pub num_constraints: MclFr,
}
impl QAP {
// build a polynomial that evaluates to target_val at x == index
// and zero for x != index for each target value.
// e.g.
// (x - 2) * (x - 3) * 3 / ((1 - 2) * (1 - 3))
// where x in [1, 2, 3]; evaluates to 3 if x == 1 and 0 if x != 1
fn build_polynomial(target_vals: &MclSparseVec) -> Polynomial {
let mut target_val_polys = vec![];
let one = MclFr::from(1);
let mut target_x = MclFr::from(1);
while target_x <= target_vals.size {
let target_val = target_vals.get(&(&target_x - &one));
// if target val is zero, simply add 0x^0
if target_val.is_zero() {
target_val_polys.push(Polynomial::new(&vec![MclFr::from(0)]));
target_x.inc();
continue;
}
let mut numerator_polys = vec![
Polynomial::new(&vec![target_val.clone()]),
];
let mut denominator = MclFr::from(1);
let mut i = MclFr::from(1);
while &i <= &target_vals.size {
if &i == &target_x {
i.inc();
continue;
}
// (x - i) to let the polynomal evaluate to zero at x = i
let numerator_poly = Polynomial::new(&vec![
-&i,
MclFr::from(1),
]);
numerator_polys.push(numerator_poly);
// (target_idx - i) to cancel out the corresponding
// numerator_poly at x = target_idx
denominator = &denominator * (&target_x - &i);
i.inc();
}
// merge denominator polynomial to numerator polynomial vector
let denominator_poly = Polynomial::new(&vec![denominator.inv()]);
let mut polys = numerator_polys;
polys.push(denominator_poly);
// aggregate numerator polynomial vector
let mut acc_poly = Polynomial::new(&vec![MclFr::from(1)]);
for poly in polys {
acc_poly = acc_poly.mul(&poly);
}
target_val_polys.push(acc_poly);
target_x.inc();
}
// aggregate polynomials for all target values
let mut res = target_val_polys[0].clone();
for x in &target_val_polys[1..] {
res += x;
}
res
}
pub fn build_p(&self, witness: &MclSparseVec) -> Polynomial {
let zero = &Polynomial::zero();
let (mut v, mut w, mut y) =
(zero.clone(), zero.clone(), zero.clone());
for i in 0..witness.size.to_usize() {
let wit = &witness[&MclFr::from(i)];
v += &self.vi[i] * wit;
w += &self.wi[i] * wit;
y += &self.yi[i] * wit;
};
(v * &w) - &y
}
// build polynomial (x-1)(x-2)..(x-num_constraints)
pub fn build_t(num_constraints: &MclFr) -> Polynomial {
let mut i = MclFr::from(1);
let mut polys = vec![];
// create (x-i) polynomials
while &i <= &num_constraints {
let poly = Polynomial::new(&vec![
-&i,
MclFr::from(1),
]);
polys.push(poly);
i.inc();
}
// aggregate (x-i) polynomial into a single polynomial
let mut acc_poly = Polynomial::new(&vec![MclFr::from(1)]);
for poly in polys {
acc_poly = acc_poly.mul(&poly);
}
acc_poly
}
pub fn build(r1cs: &R1CS) -> QAP {
/*
c1 c2 c3 (coeffs for a1, a2, a3)
for w=( 1, 2, 3),
at x=1, 3 * 1 (w[3] * a3)
at x=2, 2 * 3 (w[1] * a1)
at x=3, 0
at x=4, 2 * 2 (w[2] * a2)
*/
// w1 w2 w3 <- witness e.g. w=(x,y,z,w1,...)
// x=1 | 0 0 1 |
// x=2 | 3 0 0 |
// x=3 | 0 0 0 |
// x=4 | 0 2 0 |
// ^
// +-- constraints
let constraints = r1cs.to_constraint_matrices();
// x=1 2 3 4 <- constraints
// w1 [0 3 0 0]
// w2 [0 0 0 2]
// w3 [1 0 0 0]
// ^
// +-- witness
let constraints_v_t = constraints.a.transpose();
let constraints_w_t = constraints.b.transpose();
let constraints_y_t = constraints.c.transpose();
println!("- const v_t\n{:?}", &constraints_v_t);
println!("- const w_t\n{:?}", &constraints_w_t);
println!("- const y_t\n{:?}", &constraints_y_t);
// build polynomials for each wirness variable
// e.g. vi[0] is a polynomial for the first witness variable
// and returns 3 at x=2 and 0 at all other x's
let mut vi = vec![];
let mut wi = vec![];
let mut yi = vec![];
let mut i = MclFr::from(0);
let num_witness_values = &r1cs.witness.size;
while &i < num_witness_values {
// extract a constraint row
// x = 1 2 3 4
// wi = [0 3 0 0]
let v_row = constraints_v_t.get_row(&i);
let w_row = constraints_w_t.get_row(&i);
let y_row = constraints_y_t.get_row(&i);
// convert a constraint row to a polynomial
let v_poly = QAP::build_polynomial(&v_row);
let w_poly = QAP::build_polynomial(&w_row);
let y_poly = QAP::build_polynomial(&y_row);
vi.push(v_poly);
wi.push(w_poly);
yi.push(y_poly);
i.inc();
}
let num_constraints = constraints.a.height.clone();
QAP { vi, wi, yi, num_constraints }
}
pub fn is_valid(
&self,
witness: &MclSparseVec,
num_constraints: &MclFr,
) -> bool {
let t = QAP::build_t(num_constraints);
let p = self.build_p(witness);
match p.divide_by(&t) {
DivResult::Quotient(_) => true,
DivResult::QuotientRemainder(_) => false,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::building_block::mcl::{
mcl_fr::MclFr,
mcl_initializer::MclInitializer,
qap::constraint::Constraint,
};
#[test]
fn test_r1cs_to_polynomial() {
MclInitializer::init();
// x out t1 y t2
// 0 1 2 3 4 5
// [1, 3, 35, 9, 27, 30]
let witness = MclSparseVec::from(&vec![
MclFr::from(1),
MclFr::from(3),
MclFr::from(35),
MclFr::from(9),
MclFr::from(27),
MclFr::from(30),
]);
let witness_size = &witness.size;
// A
// 0 1 2 3 4 5
// [0, 1, 0, 0, 0, 0]
// [0, 0, 0, 1, 0, 0]
// [0, 1, 0, 0, 1, 0]
// [5, 0, 0, 0, 0, 1]
let mut a1 = MclSparseVec::new(witness_size);
a1.set(&MclFr::from(1), &MclFr::from(1));
let mut a2 = MclSparseVec::new(witness_size);
a2.set(&MclFr::from(3), &MclFr::from(1));
let mut a3 = MclSparseVec::new(witness_size);
a3.set(&MclFr::from(1), &MclFr::from(1));
a3.set(&MclFr::from(4), &MclFr::from(1));
let mut a4 = MclSparseVec::new(witness_size);
a4.set(&MclFr::from(0), &MclFr::from(5));
a4.set(&MclFr::from(5), &MclFr::from(1));
// B
// 0 1 2 3 4 5
// [0, 1, 0, 0, 0, 0]
// [0, 1, 0, 0, 0, 0]
// [1, 0, 0, 0, 0, 0]
// [1, 0, 0, 0, 0, 0]
let mut b1 = MclSparseVec::new(witness_size);
b1.set(&MclFr::from(1), &MclFr::from(1));
let mut b2 = MclSparseVec::new(witness_size);
b2.set(&MclFr::from(1), &MclFr::from(1));
let mut b3 = MclSparseVec::new(witness_size);
b3.set(&MclFr::from(0), &MclFr::from(1));
let mut b4 = MclSparseVec::new(witness_size);
b4.set(&MclFr::from(0), &MclFr::from(1));
// C
// 0 1 2 3 4 5
// [0, 0, 0, 1, 0, 0]
// [0, 0, 0, 0, 1, 0]
// [0, 0, 0, 0, 0, 1]
// [0, 0, 1, 0, 0, 0]
let mut c1 = MclSparseVec::new(witness_size);
c1.set(&MclFr::from(3), &MclFr::from(1));
let mut c2 = MclSparseVec::new(witness_size);
c2.set(&MclFr::from(4), &MclFr::from(1));
let mut c3 = MclSparseVec::new(witness_size);
c3.set(&MclFr::from(5), &MclFr::from(1));
let mut c4 = MclSparseVec::new(witness_size);
c4.set(&MclFr::from(2), &MclFr::from(1));
let constraints = vec![
Constraint::new(&a1, &b1, &c1),
Constraint::new(&a2, &b2, &c2),
Constraint::new(&a3, &b3, &c3),
Constraint::new(&a4, &b4, &c4),
];
let num_constraints = &MclFr::from(constraints.len());
let mid_beg = MclFr::from(3);
let r1cs = R1CS {
constraints,
witness: witness.clone(),
mid_beg,
};
let qap = QAP::build(&r1cs);
let is_passed = qap.is_valid(&witness, num_constraints);
assert!(is_passed);
}
#[test]
fn test_build_t() {
MclInitializer::init();
let one = &MclFr::from(1);
let two = &MclFr::from(2);
let neg_three = &-MclFr::from(3);
// (x-1)(x-2) = x^2 - 3x + 2
let z = QAP::build_t(two);
// expect [2, -3, 1]
assert_eq!(z.len(), 3);
assert_eq!(&z[0], two);
assert_eq!(&z[1], neg_three);
assert_eq!(&z[2], one);
}
}

View File

@@ -0,0 +1,159 @@
use crate::building_block::mcl::{
mcl_fr::MclFr,
mcl_sparse_vec::MclSparseVec,
mcl_sparse_matrix::MclSparseMatrix,
qap::{
constraint::Constraint,
r1cs_tmpl::R1CSTmpl,
term::Term,
},
};
use std::collections::HashMap;
use num_traits::Zero;
#[derive(Clone)]
pub struct R1CS {
pub constraints: Vec<Constraint>,
pub witness: MclSparseVec,
pub mid_beg: MclFr,
}
// matrix representing a constraint whose
// row is the multiples of each witness value i.e.
// a = [a1, a2, ...]
// *
// b = [b1, b2, ...]
// ||
// c = [c1, c2, ...]
#[derive(Debug)]
pub struct ConstraintMatrices {
pub a: MclSparseMatrix,
pub b: MclSparseMatrix,
pub c: MclSparseMatrix,
}
impl R1CS {
// build the witness vector that is in the order expected by the prover
// while validating the witness vector with the witness_instance
fn build_witness_vec(
tmpl: &R1CSTmpl,
witness_instance: &HashMap<Term, MclFr>,
) -> Result<(MclSparseVec, MclFr), String> {
// build witness vector with values assigned
let mut witness = MclSparseVec::new(&MclFr::from(tmpl.witness.len()));
let mut i = MclFr::zero();
for term in tmpl.witness.iter() {
if !witness_instance.contains_key(&term) {
return Err(format!("'{:?}' is missing in witness_instance", term));
}
witness[&i] = witness_instance.get(&term).unwrap().clone();
i.inc();
}
Ok((witness, tmpl.mid_beg.clone()))
}
// evaluate all constraints and confirm they all hold
pub fn validate(&self) -> Result<(), String> {
for constraint in &self.constraints {
let a = &(&constraint.a * &self.witness).sum();
let b = &(&constraint.b * &self.witness).sum();
let c = &(&constraint.c * &self.witness).sum();
println!("r1cs: {:?} * {:?} = {:?}", &a, &b, &c);
if &(a * b) != c {
return Err(format!("Constraint a ({:?}) * b ({:?}) = c ({:?}) doesn't hold", a, b, c));
}
}
println!("");
Ok(())
}
pub fn from_tmpl(
tmpl: &R1CSTmpl,
witness: &HashMap<Term, MclFr>,
) -> Result<R1CS, String> {
let (witness, mid_beg) = R1CS::build_witness_vec(tmpl, witness)?;
let r1cs = R1CS {
constraints: tmpl.constraints.clone(),
witness,
mid_beg,
};
Ok(r1cs)
}
pub fn to_constraint_by_witness_matrices(&self) -> ConstraintMatrices {
let mut a = vec![];
let mut b = vec![];
let mut c = vec![];
for constraint in &self.constraints {
a.push(&constraint.a * &self.witness);
b.push(&constraint.b * &self.witness);
c.push(&constraint.c * &self.witness);
}
ConstraintMatrices {
a: MclSparseMatrix::from(&a),
b: MclSparseMatrix::from(&b),
c: MclSparseMatrix::from(&c),
}
}
pub fn to_constraint_matrices(&self) -> ConstraintMatrices {
let mut a = vec![];
let mut b = vec![];
let mut c = vec![];
for constraint in &self.constraints {
a.push(constraint.a.clone());
b.push(constraint.b.clone());
c.push(constraint.c.clone());
}
ConstraintMatrices {
a: MclSparseMatrix::from(&a),
b: MclSparseMatrix::from(&b),
c: MclSparseMatrix::from(&c),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::building_block::mcl::{
mcl_initializer::MclInitializer,
qap::{
equation_parser::EquationParser,
gate::Gate,
r1cs_tmpl::R1CSTmpl,
term::Term,
},
};
#[test]
fn test_validate() {
MclInitializer::init();
let input = "(x + 2) + 4 * y == 21";
let eq = EquationParser::parse(input).unwrap();
let gates = &Gate::build(&eq);
let tmpl = &R1CSTmpl::new(gates);
let witness = HashMap::<Term, MclFr>::from([
(Term::One, MclFr::from(1)),
(Term::Var("x".to_string()), MclFr::from(3)),
(Term::Var("y".to_string()), MclFr::from(4)),
(Term::Out, eq.rhs),
(Term::TmpVar(1), MclFr::from(5)),
(Term::TmpVar(2), MclFr::from(16)),
(Term::TmpVar(3), MclFr::from(21)),
]);
let r1cs = R1CS::from_tmpl(tmpl, &witness).unwrap();
r1cs.validate().unwrap();
}
}

View File

@@ -0,0 +1,421 @@
use crate::building_block::mcl::{
mcl_fr::MclFr,
mcl_sparse_vec::MclSparseVec,
qap::{
term::Term,
gate::Gate,
constraint::Constraint,
},
};
use std::collections::HashMap;
use num_traits::Zero;
pub struct R1CSTmpl {
pub constraints: Vec<Constraint>,
pub witness: Vec<Term>,
pub indices: HashMap<Term, MclFr>,
pub mid_beg: MclFr,
}
impl R1CSTmpl {
// build witness vector whose elements in the following order:
// 1, inputs, Out, mid
fn build_witness(
inputs: &Vec<Term>,
mid: &Vec<Term>,
witness: &mut Vec<Term>,
indices: &mut HashMap::<Term, MclFr>,
) -> MclFr {
let mut i = MclFr::from(1); // `1` has already been added in new function
for x in inputs {
witness.push(x.clone());
indices.insert(x.clone(), i.clone());
i.inc();
}
witness.push(Term::Out);
indices.insert(Term::Out, i.clone());
i.inc();
let mid_beg = i.clone();
for x in mid {
witness.push(x.clone());
indices.insert(x.clone(), i.clone());
i.inc();
}
mid_beg
}
fn categorize_witness_terms(
t: &Term,
inputs: &mut Vec<Term>,
mid: &mut Vec<Term>,
) {
match t {
Term::One => (), // not categorized as inputs or mid
Term::Num(_) => (), // Num is represented as multiple of Term::One, so not adding to witness
Term::Out => (), // not categorized as inputs or mid
Term::Var(_) => if !inputs.contains(&t) { inputs.push(t.clone()) },
Term::TmpVar(_) => if !mid.contains(&t) { mid.push(t.clone()) },
Term::Sum(a, b) => {
R1CSTmpl::categorize_witness_terms(&a, inputs, mid);
R1CSTmpl::categorize_witness_terms(&b, inputs, mid);
},
}
}
fn build_constraint_vec(
vec: &mut MclSparseVec,
term: &Term,
indices: &HashMap::<Term, MclFr>,
) {
match term {
Term::Sum(a, b) => {
R1CSTmpl::build_constraint_vec(vec, &a, indices);
R1CSTmpl::build_constraint_vec(vec, &b, indices);
},
Term::Num(n) => {
vec.set(&MclFr::zero(), n); // Num is represented as Term::One at index 0 times n
},
x => {
let index = indices.get(&x).unwrap();
vec.set(index, &MclFr::from(1));
},
}
}
pub fn new(gates: &[Gate]) -> Self {
let mut witness = vec![];
let mut indices = HashMap::<Term, MclFr>::new();
// add `1` at index 0
witness.push(Term::One);
indices.insert(Term::One, MclFr::zero());
// categoraize terms contained in gates to inputs and mid
let mut inputs = vec![];
let mut mid = vec![];
for gate in gates {
R1CSTmpl::categorize_witness_terms(&gate.a, &mut inputs, &mut mid);
R1CSTmpl::categorize_witness_terms(&gate.b, &mut inputs, &mut mid);
R1CSTmpl::categorize_witness_terms(&gate.c, &mut inputs, &mut mid);
}
let mid_beg = R1CSTmpl::build_witness(&inputs, &mid, &mut witness, &mut indices);
let vec_size = &MclFr::from(witness.len());
let mut constraints = vec![];
// create a, b anc c vectors for each gate
for gate in gates {
let mut a = MclSparseVec::new(vec_size);
R1CSTmpl::build_constraint_vec(&mut a, &gate.a, &indices);
let mut b = MclSparseVec::new(vec_size);
R1CSTmpl::build_constraint_vec(&mut b, &gate.b, &indices);
let mut c = MclSparseVec::new(vec_size);
R1CSTmpl::build_constraint_vec(&mut c, &gate.c, &indices);
let constraint = Constraint { a, b, c };
constraints.push(constraint)
}
R1CSTmpl {
constraints,
witness,
indices,
mid_beg,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::building_block::mcl::{
qap::equation_parser::EquationParser,
mcl_initializer::MclInitializer,
};
#[test]
fn test_categorize_witness_terms() {
MclInitializer::init();
// Num term should not be categorized as input or mid
{
let mut inputs = vec![];
let mut mid = vec![];
let term = &Term::Num(MclFr::from(9));
R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid);
assert_eq!(inputs.len(), 0);
assert_eq!(mid.len(), 0);
}
// One term should not be categorized as input or mid
{
let mut inputs = vec![];
let mut mid = vec![];
let term = &Term::One;
R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid);
assert_eq!(inputs.len(), 0);
assert_eq!(mid.len(), 0);
}
// Var term should be categorized as input
{
let mut inputs = vec![];
let mut mid = vec![];
let term = &Term::Var("x".to_string());
R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid);
assert_eq!(inputs.len(), 1);
assert_eq!(mid.len(), 0);
}
// Out term should be not categorized as input or mid
{
let mut inputs = vec![];
let mut mid = vec![];
let term = &Term::Out;
R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid);
assert_eq!(inputs.len(), 0);
assert_eq!(mid.len(), 0);
}
// TmpVar term should be categorized as mid
{
let mut inputs = vec![];
let mut mid = vec![];
let term = &Term::TmpVar(1);
R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid);
assert_eq!(inputs.len(), 0);
assert_eq!(mid.len(), 1);
}
// Sum term should be recursively categorized
{
let mut inputs = vec![];
let mut mid = vec![];
let y = Term::Var("y".to_string());
let z = Term::Var("z".to_string());
let term = &Term::Sum(Box::new(y.clone()), Box::new(z.clone()));
R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid);
assert_eq!(inputs.len(), 2);
assert_eq!(mid.len(), 0);
}
}
#[test]
fn test_build_witness() {
MclInitializer::init();
let a = Term::Var("a".to_string());
let b = Term::Var("b".to_string());
let sum = Term::Sum(Box::new(a), Box::new(b));
let terms = vec![
Term::Num(MclFr::from(9)), // Num should be ignored
Term::One, // One is discarded in categorize_witness_terms
Term::Var("x".to_string()), // Var should be added
Term::Var("x".to_string()), // the same Var should be added twice
Term::Var("y".to_string()), // different Var should be added
Term::TmpVar(1), // TmpVar should be added
Term::TmpVar(1), // same TmpVar should not be added twice
Term::TmpVar(2), // different TmpVar should be added
Term::Out, // Out should be added
Term::Out, // Out should not be added twice
sum, // sum should be added recursively
];
let mut inputs = vec![];
let mut mid = vec![];
for term in &terms {
R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid);
}
let mut witness = vec![];
let mut indices = HashMap::<Term, MclFr>::new();
let mid_beg = R1CSTmpl::build_witness(&inputs, &mid, &mut witness, &mut indices);
assert!(mid_beg == MclFr::from(6));
// 7 since One has been discarded and Out is added in build_witness
assert_eq!(indices.len(), 7);
assert_eq!(witness.len(), 7);
// check if witness is correctly built
let exp = vec![
// One has been discarded
Term::Var("x".to_string()),
Term::Var("y".to_string()),
Term::Var("a".to_string()),
Term::Var("b".to_string()),
Term::Out, // build_witness adds Out
Term::TmpVar(1),
Term::TmpVar(2),
];
assert!(witness == exp);
// check if indices map is correctly built
assert!(indices.get(&Term::One).is_none());
assert!(indices.get(&Term::Var("x".to_string())).unwrap() == &MclFr::from(1));
assert!(indices.get(&Term::Var("y".to_string())).unwrap() == &MclFr::from(2));
assert!(indices.get(&Term::Var("a".to_string())).unwrap() == &MclFr::from(3));
assert!(indices.get(&Term::Var("b".to_string())).unwrap() == &MclFr::from(4));
assert!(indices.get(&Term::Out).unwrap() == &MclFr::from(5));
assert!(indices.get(&Term::TmpVar(1)).unwrap() == &MclFr::from(6));
assert!(indices.get(&Term::TmpVar(2)).unwrap() == &MclFr::from(7));
}
#[test]
fn test_new() {
MclInitializer::init();
let gates = vec![];
let tmpl = R1CSTmpl::new( &gates);
assert_eq!(tmpl.indices.len(), 2);
// if gates is empty, witness should contain only One term and Out term
assert_eq!(tmpl.indices.get(&Term::One).unwrap(), &MclFr::from(0));
assert_eq!(tmpl.indices.get(&Term::Out).unwrap(), &MclFr::from(1));
assert_eq!(tmpl.witness.len(), 2);
assert_eq!(tmpl.witness[0], Term::One);
assert_eq!(tmpl.witness[1], Term::Out);
}
#[test]
fn test_constraint_generation() {
MclInitializer::init();
{
// Num
let mut inputs = vec![];
let mut mid = vec![];
let term = &Term::Num(MclFr::from(4));
R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid);
let mut witness = vec![];
let mut indices = HashMap::<Term, MclFr>::new();
let mid_beg = R1CSTmpl::build_witness(&inputs, &mid, &mut witness, &mut indices);
assert!(mid_beg == MclFr::from(2));
let mut constraint = MclSparseVec::new(&MclFr::from(2));
R1CSTmpl::build_constraint_vec(&mut constraint, &term, &indices);
// should be mapped to One term at index 0
assert_eq!(constraint.get(&MclFr::zero()), &MclFr::from(4));
}
{
// Sum
let mut inputs = vec![];
let mut mid = vec![];
let y = Term::Var("y".to_string());
let z = Term::Var("z".to_string());
let term = &Term::Sum(Box::new(y.clone()), Box::new(z.clone()));
R1CSTmpl::categorize_witness_terms(term, &mut inputs, &mut mid);
let mut witness = vec![];
let mut indices = HashMap::<Term, MclFr>::new();
let mid_beg = R1CSTmpl::build_witness(&inputs, &mid, &mut witness, &mut indices);
assert!(mid_beg == MclFr::from(4));
let mut constraint = MclSparseVec::new(&MclFr::from(3));
R1CSTmpl::build_constraint_vec(&mut constraint, &term, &indices);
// y and z should be stored at index 1 and 2 of witness vector respectively
assert_eq!(constraint.get(&MclFr::from(1)), &MclFr::from(1));
assert_eq!(constraint.get(&MclFr::from(1)), &MclFr::from(1));
}
}
#[test]
fn test_witness_indices() {
MclInitializer::init();
let input = "(3 * x + 4) / 2 == 11";
let eq = EquationParser::parse(input).unwrap();
let gates = &Gate::build(&eq);
let tmpl = R1CSTmpl::new(gates);
let h = tmpl.indices;
let w = [
Term::One,
Term::Var("x".to_string()),
Term::Out,
Term::TmpVar(1),
Term::TmpVar(2),
Term::TmpVar(3),
];
assert_eq!(h.len(), w.len());
for (i, term) in w.iter().enumerate() {
assert_eq!(h.get(&term).unwrap(), &MclFr::from(i));
}
}
fn term_to_str(tmpl: &R1CSTmpl, vec: &MclSparseVec) -> String {
MclInitializer::init();
let mut indices = vec.indices().to_vec();
indices.sort(); // sort to make indices order deterministic
let s = indices.iter().map(|i| {
let i_usize: usize = i.to_usize();
match &tmpl.witness[i_usize] {
Term::Var(s) => s.clone(),
Term::TmpVar(i) => format!("t{}", i),
Term::One => format!("{:?}", &vec.get(i)),
Term::Out => "out".to_string(),
// currently not handling Term::Sum since it's not used in tests
_ => "?".to_string(),
}
}).collect::<Vec<String>>().join(" + ");
format!("{}", s)
}
#[test]
fn test_r1cs_build_a_b_c_matrix() {
MclInitializer::init();
let input = "3 * x + 4 == 11";
let eq = EquationParser::parse(input).unwrap();
let gates = &Gate::build(&eq);
let tmpl = R1CSTmpl::new(gates);
let mut res = vec![];
for constraint in &tmpl.constraints {
let a = term_to_str(&tmpl, &constraint.a);
let b = term_to_str(&tmpl, &constraint.b);
let c = term_to_str(&tmpl, &constraint.c);
res.push((a, b, c));
}
assert_eq!(res.len(), 3);
assert_eq!(res[0], ("3".to_string(), "x".to_string(), "t1".to_string()));
assert_eq!(res[1], ("4 + t1".to_string(), "1".to_string(), "t2".to_string()));
assert_eq!(res[2], ("t2".to_string(), "1".to_string(), "out".to_string()));
}
#[test]
fn blog_post_1_example_1() {
MclInitializer::init();
let expr = "(x * x * x) + x + 5 == 35";
let eq = EquationParser::parse(expr).unwrap();
let gates = &Gate::build(&eq);
let r1cs_tmpl = R1CSTmpl::new(gates);
println!("{:?}", r1cs_tmpl.witness);
}
#[test]
fn blog_post_1_example_2() {
MclInitializer::init();
let expr = "(x * x * x) + x + 5 == 35";
let eq = EquationParser::parse(expr).unwrap();
let gates = &Gate::build(&eq);
let r1cs_tmpl = R1CSTmpl::new(gates);
println!("w = {:?}", r1cs_tmpl.witness);
println!("{:?}", r1cs_tmpl.constraints);
}
}

View File

@@ -0,0 +1,33 @@
use crate::building_block::mcl::{
mcl_fr::MclFr,
qap::config::SignalId,
};
#[derive(Clone, Hash, PartialEq, Eq)]
pub enum Term {
Num(MclFr),
One,
Out,
Sum(Box<Term>, Box<Term>), // Sum will not contain Out and Sum itself
TmpVar(SignalId),
Var(String),
}
impl Term {
pub fn var(name: &str) -> Term {
Term::Var(name.to_string())
}
}
impl std::fmt::Debug for Term {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Term::Num(n) => write!(f, "{:?}", n),
Term::One => write!(f, "1"),
Term::Out => write!(f, "out"),
Term::Sum(a, b) => write!(f, "({:?} + {:?})", a, b),
Term::TmpVar(n) => write!(f, "t{:?}", n),
Term::Var(s) => write!(f, "{:?}", s),
}
}
}

View File

@@ -1,6 +1,7 @@
// pub mod elliptic_curve;
pub mod field;
pub mod hasher;
pub mod mcl;
pub mod random_number;
pub mod curves;
pub mod to_bigint;

View File

@@ -0,0 +1,138 @@
use crate::{
building_block::mcl::{
mcl_fr::MclFr,
mcl_g1::MclG1,
mcl_g2::MclG2,
mcl_gt::MclGT,
pairing::Pairing,
qap::qap::QAP,
},
zk::w_trusted_setup::groth16_mcl::prover::Prover,
};
pub struct G1 {
pub alpha: MclG1,
pub beta: MclG1,
pub delta: MclG1,
pub xi: Vec<MclG1>, // x powers
pub uvw_stmt: Vec<MclG1>, // beta*u(x) + alpha*v(x) + w(x) / div (statement)
pub uvw_wit: Vec<MclG1>, // beta*u(x) + alpha*v(x) + w(x) / div (witness)
pub ht_by_delta: MclG1, // h(x) * t(x) / delta
}
pub struct G2 {
pub beta: MclG2,
pub gamma: MclG2,
pub delta: MclG2,
pub xi: Vec<MclG2>, // x powers
}
pub struct GT {
pub alpha_beta: MclGT,
}
#[allow(non_snake_case)]
pub struct CRS {
pub g1: G1,
pub g2: G2,
pub gt: GT,
}
impl CRS {
// 0, 1, .., l, l+1, .., m
// +---------+ +--------+
// statement witness
pub fn new(
prover: &Prover,
pairing: &Pairing,
) -> Self {
println!("--> Building sigma...");
let g = &MclG1::g();
let h = &MclG2::g();
// sample random non-zero field element
let alpha = &MclFr::rand(true);
let beta = &MclFr::rand(true);
let gamma = &MclFr::rand(true);
let delta = &MclFr::rand(true);
let x = &MclFr::rand(true);
macro_rules! calc_uvw_div {
($from: expr, $to: expr, $div_factor: expr) => {
{
let mut ys: Vec<MclG1> = vec![];
let mut i = $from.clone();
while &i <= $to {
let ui = beta * &prover.ui[i].eval_at(x);
let vi = alpha * &prover.vi[i].eval_at(x);
let wi = &prover.wi[i].eval_at(x);
let y = (ui + vi + wi) * $div_factor;
ys.push(g * y);
i += 1;
}
ys
}
}
}
let uvw_stmt = calc_uvw_div!(0, &prover.l.to_usize(), &gamma.inv());
let uvw_wit = calc_uvw_div!(&prover.l.to_usize() + 1, &prover.m.to_usize(), &delta.inv());
macro_rules! calc_n_pows {
($point_type: ty, $x: expr) => {
{
let generator = &<$point_type>::g();
let mut ys: Vec<$point_type> = vec![];
let mut x_pow = MclFr::from(1);
for _ in 0..prover.n.to_usize(){
ys.push(generator * &x_pow);
x_pow = x_pow * x;
}
ys
}
}
}
let xi_g1 = calc_n_pows!(MclG1, x);
let ht_by_delta = {
let h = &prover.h.eval_at(x);
let t = &QAP::build_t(&prover.n).eval_at(x);
let v = h * t * &delta.inv();
MclG1::g() * &v
};
let g1 = G1 {
alpha: g * alpha,
beta: g * beta,
delta: g * delta,
xi: xi_g1,
uvw_stmt,
uvw_wit,
ht_by_delta,
};
let xi_g2 = calc_n_pows!(MclG2, x);
let g2 = G2 {
beta: h * beta,
gamma: h * gamma,
delta: h * delta,
xi: xi_g2,
};
let gt = GT {
alpha_beta: pairing.e(&g1.alpha, &g2.beta),
};
CRS {
g1,
g2,
gt,
}
}
}

View File

@@ -0,0 +1,5 @@
pub mod crs;
pub mod proof;
pub mod prover;
pub mod verifier;
pub mod wires;

View File

@@ -0,0 +1,12 @@
use crate::building_block::mcl::{
mcl_g1::MclG1,
mcl_g2::MclG2,
};
#[allow(non_snake_case)]
pub struct Proof {
pub A: MclG1,
pub B: MclG2,
pub C: MclG1,
}

View File

@@ -0,0 +1,179 @@
use crate::{
building_block::mcl::{
mcl_fr::MclFr,
mcl_g1::MclG1,
mcl_g2::MclG2,
polynomial::{
DivResult,
Polynomial,
},
qap::{
equation_parser::EquationParser,
gate::Gate,
qap::QAP,
r1cs::R1CS,
r1cs_tmpl::R1CSTmpl,
term::Term,
},
},
zk::w_trusted_setup::groth16_mcl::{
crs::CRS,
proof::Proof,
wires::Wires,
},
};
use num_traits::Zero;
use std::collections::HashMap;
pub struct Prover {
pub n: MclFr, // # of constraints
pub l: MclFr, // end index of statement variables
pub m: MclFr, // end index of statement + witness variables
pub wires: Wires,
pub h: Polynomial,
pub t: Polynomial,
pub ui: Vec<Polynomial>,
pub vi: Vec<Polynomial>,
pub wi: Vec<Polynomial>,
}
impl Prover {
pub fn new(
expr: &str,
witness_map: &HashMap<Term, MclFr>,
) -> Self {
let eq = EquationParser::parse(expr).unwrap();
let gates = &Gate::build(&eq);
let tmpl = &R1CSTmpl::new(gates);
let r1cs = R1CS::from_tmpl(tmpl, &witness_map).unwrap();
r1cs.validate().unwrap();
let qap = QAP::build(&r1cs);
let t = QAP::build_t(&MclFr::from(tmpl.constraints.len()));
let h = {
let p = qap.build_p(&r1cs.witness);
match p.divide_by(&t) {
DivResult::Quotient(q) => q,
_ => panic!("p should be divisible by t"),
}
};
let l = &tmpl.mid_beg - MclFr::from(1);
let m = MclFr::from(tmpl.witness.len() - 1);
let wires = Wires::new(&r1cs.witness.clone(), &l);
let n = MclFr::from(tmpl.constraints.len());
Prover {
n,
l,
m,
wires,
t,
h,
ui: qap.vi.clone(),
vi: qap.wi.clone(),
wi: qap.yi.clone(),
}
}
#[allow(non_snake_case)]
pub fn prove(&self, crs: &CRS) -> Proof {
println!("--> Generating proof...");
let r = &MclFr::rand(true);
let s = &MclFr::rand(true);
let (A, B, B_g1) = {
let mut sum_term_A = MclG1::zero();
let mut sum_term_B = MclG2::zero();
let mut sum_term_B_g1 = MclG1::zero();
for i in 0..=self.m.to_usize() {
let ai = &self.wires[i];
let ui_prod = self.ui[i].eval_with_g1_hidings(&crs.g1.xi) * ai;
let vi_prod = self.vi[i].eval_with_g2_hidings(&crs.g2.xi) * ai;
let vi_prod_g1 = self.vi[i].eval_with_g1_hidings(&crs.g1.xi) * ai;
sum_term_A += ui_prod;
sum_term_B += vi_prod;
sum_term_B_g1 += vi_prod_g1;
}
let A = &crs.g1.alpha + &sum_term_A + &crs.g1.delta * r;
let B = &crs.g2.beta + &sum_term_B + &crs.g2.delta * s;
let B_g1 = &crs.g1.beta + &sum_term_B_g1 + &crs.g1.delta * s;
(A, B, B_g1)
};
let C = {
let mut sum = MclG1::zero();
let wit_beg = self.l.to_usize() + 1;
for i in wit_beg..=self.m.to_usize() {
let ai = &self.wires[i];
sum += &crs.g1.uvw_wit[i - wit_beg] * ai;
}
sum
+ &crs.g1.ht_by_delta
+ &A * s
+ &B_g1 * r
+ -(&crs.g1.delta * r * s)
};
Proof {
A,
B,
C,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
building_block::mcl::{
pairing::Pairing,
mcl_initializer::MclInitializer,
},
zk::w_trusted_setup::groth16_mcl::verifier::Verifier,
};
#[test]
fn test_generate_proof_and_verify() {
MclInitializer::init();
let expr = "(x * x * x) + x + 5 == 35";
println!("Expr: {}\n", expr);
let eq = EquationParser::parse(expr).unwrap();
let witness_map = {
use crate::building_block::mcl::qap::term::Term::*;
HashMap::<Term, MclFr>::from([
(Term::One, MclFr::from(1)),
(Term::var("x"), MclFr::from(3)),
(TmpVar(1), MclFr::from(9)),
(TmpVar(2), MclFr::from(27)),
(TmpVar(3), MclFr::from(8)),
(TmpVar(4), MclFr::from(35)),
(Out, eq.rhs),
])
};
let prover = &Prover::new(expr, &witness_map);
let pairing = &Pairing;
let verifier = &Verifier::new(pairing);
let crs = CRS::new(prover, pairing);
let proof = prover.prove(&crs);
let stmt_wires = &prover.wires.statement();
let result = verifier.verify(
&proof,
&crs,
stmt_wires,
);
assert!(result);
}
}

View File

@@ -0,0 +1,55 @@
// Implementation of protocol 2 described on page 5 in https://eprint.iacr.org/2013/279.pdf
use crate::{
building_block::mcl::{
mcl_fr::MclFr,
mcl_g1::MclG1,
mcl_g2::MclG2,
mcl_sparse_vec::MclSparseVec,
pairing::Pairing,
},
zk::w_trusted_setup::groth16_mcl::{
crs::CRS,
proof::Proof,
},
};
use num_traits::Zero;
pub struct Verifier {
pairing: Pairing,
}
impl Verifier {
pub fn new(pairing: &Pairing) -> Self {
Verifier {
pairing: pairing.clone(),
}
}
pub fn verify(
&self,
proof: &Proof,
crs: &CRS,
stmt_wires: &MclSparseVec,
) -> bool {
let e = |a: &MclG1, b: &MclG2| self.pairing.e(a, b);
println!("--> Verifying Groth16 proof...");
let lhs = e(&proof.A, &proof.B);
let mut sum_term = MclG1::zero();
for i in 0..stmt_wires.size.to_usize() {
let ai = &stmt_wires[&MclFr::from(i)];
sum_term += &crs.g1.uvw_stmt[i] * ai;
}
let rhs =
&crs.gt.alpha_beta
* e(&sum_term, &crs.g2.gamma)
* e(&proof.C, &crs.g2.delta)
;
lhs == rhs
}
}

View File

@@ -0,0 +1,78 @@
use crate::building_block::mcl::{
mcl_fr::MclFr,
mcl_sparse_vec::MclSparseVec,
};
use core::ops::Index;
use num_traits::Zero;
// wires:
// 0, 1, .., l, l+1, .., m
// +---------+ +--------+
// statement witness
pub struct Wires {
sv: MclSparseVec,
witness_beg: MclFr,
}
impl Wires {
// l is index of the last statement wire
pub fn new(sv: &MclSparseVec, l: &MclFr) -> Self {
Wires {
sv: sv.clone(),
witness_beg: l + MclFr::from(1),
}
}
pub fn statement(&self) -> MclSparseVec {
self.sv.slice(&MclFr::zero(), &self.witness_beg)
}
pub fn witness(&self) -> MclSparseVec {
let from = &self.witness_beg;
let to = &self.sv.size;
self.sv.slice(from, to)
}
}
impl Index<usize> for Wires {
type Output = MclFr;
fn index(&self, index: usize) -> &Self::Output {
&self.sv[&MclFr::from(index)]
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_wire_indices() {
// [1,3,35,9,27,8,35]
let mut sv = MclSparseVec::new(&MclFr::from(7));
sv[&MclFr::from(0)] = MclFr::from(1);
sv[&MclFr::from(1)] = MclFr::from(3);
sv[&MclFr::from(2)] = MclFr::from(35); // <-- l
sv[&MclFr::from(3)] = MclFr::from(9);
sv[&MclFr::from(4)] = MclFr::from(27);
sv[&MclFr::from(5)] = MclFr::from(8);
sv[&MclFr::from(6)] = MclFr::from(35);
let w = Wires::new(&sv, &MclFr::from(2));
let st = &w.statement();
assert!(st.size == MclFr::from(3));
assert!(st[&MclFr::from(0)] == MclFr::from(1));
assert!(st[&MclFr::from(1)] == MclFr::from(3));
assert!(st[&MclFr::from(2)] == MclFr::from(35));
let wit = &w.witness();
assert!(wit.size == MclFr::from(4));
assert!(wit[&MclFr::from(0)] == MclFr::from(9));
assert!(wit[&MclFr::from(1)] == MclFr::from(27));
assert!(wit[&MclFr::from(2)] == MclFr::from(8));
assert!(wit[&MclFr::from(3)] == MclFr::from(35));
}
}

View File

@@ -1,3 +1,4 @@
pub mod groth16;
pub mod groth16_mcl;
pub mod pinocchio;
pub mod qap;