halo2 wip and tests

themighty1
2022-09-21 21:07:35 +03:00
parent cffba8cc6e
commit 455420b282
36 changed files with 2876 additions and 2108 deletions


@@ -52,3 +52,6 @@ ark-std = { git = "https://github.com/arkworks-rs/std" }
ark-ec = { git = "https://github.com/arkworks-rs/algebra" }
ark-ff = { git = "https://github.com/arkworks-rs/algebra" }
ark-serialize = { git = "https://github.com/arkworks-rs/algebra" }
[dev-dependencies]
hex = "0.4"

README

@@ -1,29 +0,0 @@
This repo generates a circom circuit which is used to decode output labels from a garbled circuit (GC).
Install snarkjs https://github.com/iden3/snarkjs
Download powers of tau^14 https://hermez.s3-eu-west-1.amazonaws.com/powersOfTau28_hez_final_14.ptau
Run:
python3 script.py 15
# 15 is how much plaintext (in Field elements of ~32 bytes) we want
# to decode inside the snark. (For tau^14 max is 21)
# if you need more than 21, you'll need to download another ptau file from
# https://github.com/iden3/snarkjs#7-prepare-phase-2
circom circuit.circom --r1cs --wasm
snarkjs groth16 setup circuit.r1cs powersOfTau28_hez_final_14.ptau circuit_0000.zkey
# snarkjs groth16 setup circuit.r1cs pot14_bls12_final.ptau circuit_0000.zkey
snarkjs zkey contribute circuit_0000.zkey circuit_final.zkey -v -e="Notary's entropy"
snarkjs zkey export verificationkey circuit_final.zkey verification_key.json
snarkjs groth16 fullprove input.json circuit_js/circuit.wasm circuit_final.zkey proof.json public.json
snarkjs groth16 verify verification_key.json public.json proof.json
We can generate circuit.wasm and circuit.r1cs deterministically with circom 2.0.5+
circom circuit.circom --r1cs --wasm
and then ship .wasm on the User side and .r1cs on the Notary side

circom/README Normal file

@@ -0,0 +1,24 @@
This folder contains files used by the snarkjs_backend.
To use that backend, make sure you have Node.js installed (tested on Node v16.17.1)
Install dependencies with:
npm install
powersOfTau28_hez_final_14.ptau was downloaded from https://hermez.s3-eu-west-1.amazonaws.com/powersOfTau28_hez_final_14.ptau
Whenever circuit.circom is modified, delete circuit_0000.zkey and run:
circom circuit.circom --r1cs --wasm
All the commands below will be run by the prover/verifier from the .mjs files:
snarkjs groth16 setup circuit.r1cs powersOfTau28_hez_final_14.ptau circuit_0000.zkey
snarkjs zkey contribute circuit_0000.zkey circuit_final.zkey -v -e="Notary's entropy"
snarkjs zkey export verificationkey circuit_final.zkey verification_key.json
snarkjs groth16 fullprove input.json circuit_js/circuit.wasm circuit_final.zkey proof.json public.json
snarkjs groth16 verify verification_key.json public.json proof.json
We can generate circuit.wasm and circuit.r1cs deterministically with circom 2.0.5+
circom circuit.circom --r1cs --wasm
and then ship circuit.wasm on the User side and circuit.r1cs on the Notary side


@@ -3,7 +3,7 @@ include "./poseidon.circom";
include "./utils.circom";
template Main() {
// Poseidon hash width (how many field elements are permuted at a time)
// Poseidon hash rate (how many field elements are permuted at a time)
var w = 16;
// The amount of last field element's high bits (in big-endian) to use for
// the plaintext. The rest of it will be used for the salt.
@@ -12,19 +12,20 @@ template Main() {
signal input plaintext_hash;
signal input label_sum_hash;
signal input plaintext[w];
signal input labelsum_salt;
signal input salt;
signal input delta[w-1][253];
signal input delta_last[last_fe_bits];
signal input sum_of_zero_labels;
signal sums[w];
// according to the Poseidon paper, the 2nd element of the Poseidon state
// is the hash digest
component hash = PoseidonEx(w, 2);
hash.initialState <== 0;
for (var i = 0; i<w; i++) {
for (var i = 0; i < w-1; i++) {
hash.inputs[i] <== plaintext[i];
}
// add salt to the last element of the plaintext, shifting it left first
hash.inputs[w-1] <== plaintext[w-1] * (1 << 128) + salt;
log(1);
plaintext_hash === hash.out[1];
log(2);
@@ -41,7 +42,7 @@ template Main() {
var useful_bits = i < w-1 ? 253 : last_fe_bits;
ip[i] = InnerProd(useful_bits);
ip[i].plaintext <== plaintext[i];
for (var j=0; j<useful_bits; j++) {
for (var j=0; j < useful_bits; j++) {
if (i < w-1){
ip[i].deltas[j] <== delta[i][j];
}
@@ -57,8 +58,9 @@ template Main() {
component ls_hash = PoseidonEx(1, 2);
ls_hash.initialState <== 0;
// shift the sum to the left and put the salt into the last 128 bits
ls_hash.inputs[0] <== (sum_of_zero_labels + sum_of_deltas[w]) * (1 << 128) + labelsum_salt;
ls_hash.inputs[0] <== (sum_of_zero_labels + sum_of_deltas[w]) * (1 << 128) + salt;
log(3);
label_sum_hash === ls_hash.out[1];
log(4);
}
component main {public [sum_of_zero_labels, plaintext_hash, label_sum_hash, delta, delta_last]} = Main();

circom/circuit.r1cs Normal file

Binary file not shown.

Binary file not shown.


@@ -0,0 +1,20 @@
const wc = require("./witness_calculator.js");
const { readFileSync, writeFile } = require("fs");
if (process.argv.length != 5) {
console.log("Usage: node generate_witness.js <file.wasm> <input.json> <output.wtns>");
} else {
const input = JSON.parse(readFileSync(process.argv[3], "utf8"));
const buffer = readFileSync(process.argv[2]);
wc(buffer).then(async witnessCalculator => {
// const w= await witnessCalculator.calculateWitness(input,0);
// for (let i=0; i< w.length; i++){
// console.log(w[i]);
// }
const buff= await witnessCalculator.calculateWTNSBin(input,0);
writeFile(process.argv[4], buff, function(err) {
if (err) throw err;
});
});
}


@@ -0,0 +1,306 @@
module.exports = async function builder(code, options) {
options = options || {};
let wasmModule;
try {
wasmModule = await WebAssembly.compile(code);
} catch (err) {
console.log(err);
console.log("\nTry to run circom --c in order to generate c++ code instead\n");
throw new Error(err);
}
let wc;
const instance = await WebAssembly.instantiate(wasmModule, {
runtime: {
exceptionHandler : function(code) {
let errStr;
if (code == 1) {
errStr= "Signal not found. ";
} else if (code == 2) {
errStr= "Too many signals set. ";
} else if (code == 3) {
errStr= "Signal already set. ";
} else if (code == 4) {
errStr= "Assert Failed. ";
} else if (code == 5) {
errStr= "Not enough memory. ";
} else if (code == 6) {
errStr= "Input signal array access exceeds the size";
} else {
errStr= "Unknown error\n";
}
// get error message from wasm
errStr += getMessage();
throw new Error(errStr);
},
showSharedRWMemory: function() {
printSharedRWMemory ();
}
}
});
const sanityCheck =
options
// options &&
// (
// options.sanityCheck ||
// options.logGetSignal ||
// options.logSetSignal ||
// options.logStartComponent ||
// options.logFinishComponent
// );
wc = new WitnessCalculator(instance, sanityCheck);
return wc;
function getMessage() {
var message = "";
var c = instance.exports.getMessageChar();
while ( c != 0 ) {
message += String.fromCharCode(c);
c = instance.exports.getMessageChar();
}
return message;
}
function printSharedRWMemory () {
const shared_rw_memory_size = instance.exports.getFieldNumLen32();
const arr = new Uint32Array(shared_rw_memory_size);
for (let j=0; j<shared_rw_memory_size; j++) {
arr[shared_rw_memory_size-1-j] = instance.exports.readSharedRWMemory(j);
}
console.log(fromArray32(arr));
}
};
class WitnessCalculator {
constructor(instance, sanityCheck) {
this.instance = instance;
this.version = this.instance.exports.getVersion();
this.n32 = this.instance.exports.getFieldNumLen32();
this.instance.exports.getRawPrime();
const arr = new Uint32Array(this.n32);
for (let i=0; i<this.n32; i++) {
arr[this.n32-1-i] = this.instance.exports.readSharedRWMemory(i);
}
this.prime = fromArray32(arr);
this.witnessSize = this.instance.exports.getWitnessSize();
this.sanityCheck = sanityCheck;
}
circom_version() {
return this.instance.exports.getVersion();
}
async _doCalculateWitness(input, sanityCheck) {
//input is assumed to be a map from signals to arrays of bigints
this.instance.exports.init((this.sanityCheck || sanityCheck) ? 1 : 0);
const keys = Object.keys(input);
var input_counter = 0;
keys.forEach( (k) => {
const h = fnvHash(k);
const hMSB = parseInt(h.slice(0,8), 16);
const hLSB = parseInt(h.slice(8,16), 16);
const fArr = flatArray(input[k]);
let signalSize = this.instance.exports.getInputSignalSize(hMSB, hLSB);
if (signalSize < 0){
throw new Error(`Signal ${k} not found\n`);
}
if (fArr.length < signalSize) {
throw new Error(`Not enough values for input signal ${k}\n`);
}
if (fArr.length > signalSize) {
throw new Error(`Too many values for input signal ${k}\n`);
}
for (let i=0; i<fArr.length; i++) {
const arrFr = toArray32(BigInt(fArr[i])%this.prime,this.n32)
for (let j=0; j<this.n32; j++) {
this.instance.exports.writeSharedRWMemory(j,arrFr[this.n32-1-j]);
}
try {
this.instance.exports.setInputSignal(hMSB, hLSB,i);
input_counter++;
} catch (err) {
// console.log(`After adding signal ${i} of ${k}`)
throw new Error(err);
}
}
});
if (input_counter < this.instance.exports.getInputSize()) {
throw new Error(`Not all inputs have been set. Only ${input_counter} out of ${this.instance.exports.getInputSize()}`);
}
}
async calculateWitness(input, sanityCheck) {
const w = [];
await this._doCalculateWitness(input, sanityCheck);
for (let i=0; i<this.witnessSize; i++) {
this.instance.exports.getWitness(i);
const arr = new Uint32Array(this.n32);
for (let j=0; j<this.n32; j++) {
arr[this.n32-1-j] = this.instance.exports.readSharedRWMemory(j);
}
w.push(fromArray32(arr));
}
return w;
}
async calculateBinWitness(input, sanityCheck) {
const buff32 = new Uint32Array(this.witnessSize*this.n32);
const buff = new Uint8Array( buff32.buffer);
await this._doCalculateWitness(input, sanityCheck);
for (let i=0; i<this.witnessSize; i++) {
this.instance.exports.getWitness(i);
const pos = i*this.n32;
for (let j=0; j<this.n32; j++) {
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
}
}
return buff;
}
async calculateWTNSBin(input, sanityCheck) {
const buff32 = new Uint32Array(this.witnessSize*this.n32+this.n32+11);
const buff = new Uint8Array( buff32.buffer);
await this._doCalculateWitness(input, sanityCheck);
//"wtns"
buff[0] = "w".charCodeAt(0)
buff[1] = "t".charCodeAt(0)
buff[2] = "n".charCodeAt(0)
buff[3] = "s".charCodeAt(0)
//version 2
buff32[1] = 2;
//number of sections: 2
buff32[2] = 2;
//id section 1
buff32[3] = 1;
const n8 = this.n32*4;
//id section 1 length in 64bytes
const idSection1length = 8 + n8;
const idSection1lengthHex = idSection1length.toString(16);
buff32[4] = parseInt(idSection1lengthHex.slice(0,8), 16);
buff32[5] = parseInt(idSection1lengthHex.slice(8,16), 16);
//this.n32
buff32[6] = n8;
//prime number
this.instance.exports.getRawPrime();
var pos = 7;
for (let j=0; j<this.n32; j++) {
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
}
pos += this.n32;
// witness size
buff32[pos] = this.witnessSize;
pos++;
//id section 2
buff32[pos] = 2;
pos++;
// section 2 length
const idSection2length = n8*this.witnessSize;
const idSection2lengthHex = idSection2length.toString(16);
buff32[pos] = parseInt(idSection2lengthHex.slice(0,8), 16);
buff32[pos+1] = parseInt(idSection2lengthHex.slice(8,16), 16);
pos += 2;
for (let i=0; i<this.witnessSize; i++) {
this.instance.exports.getWitness(i);
for (let j=0; j<this.n32; j++) {
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
}
pos += this.n32;
}
return buff;
}
}
function toArray32(rem,size) {
const res = []; //new Uint32Array(size); //has no unshift
const radix = BigInt(0x100000000);
while (rem) {
res.unshift( Number(rem % radix));
rem = rem / radix;
}
if (size) {
var i = size - res.length;
while (i>0) {
res.unshift(0);
i--;
}
}
return res;
}
function fromArray32(arr) { //returns a BigInt
var res = BigInt(0);
const radix = BigInt(0x100000000);
for (let i = 0; i<arr.length; i++) {
res = res*radix + BigInt(arr[i]);
}
return res;
}
function flatArray(a) {
var res = [];
fillArray(res, a);
return res;
function fillArray(res, a) {
if (Array.isArray(a)) {
for (let i=0; i<a.length; i++) {
fillArray(res, a[i]);
}
} else {
res.push(a);
}
}
}
function fnvHash(str) {
const uint64_max = BigInt(2) ** BigInt(64);
let hash = BigInt("0xCBF29CE484222325");
for (var i = 0; i < str.length; i++) {
hash ^= BigInt(str[i].charCodeAt());
hash *= BigInt(0x100000001B3);
hash %= uint64_max;
}
let shash = hash.toString(16);
let n = 16 - shash.length;
shash = '0'.repeat(n).concat(shash);
return shash;
}
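For reference, the signal-name hash above is plain 64-bit FNV-1a over UTF-16 code units. A minimal Rust sketch of the same function (our naming, not part of this repo):

// 64-bit FNV-1a over the UTF-16 code units of a signal name,
// mirroring fnvHash() above (charCodeAt yields UTF-16 units);
// wrapping_mul gives the `hash %= 2^64` reduction for free
fn fnv_hash(name: &str) -> u64 {
    let mut hash: u64 = 0xCBF29CE484222325;
    for unit in name.encode_utf16() {
        hash ^= unit as u64;
        hash = hash.wrapping_mul(0x100000001B3);
    }
    hash
}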


@@ -17,17 +17,17 @@ async function main(){
process.exit(1);
}
const r1cs = fs.readFileSync("circuit.r1cs");
const ptau = fs.readFileSync("powersOfTau28_hez_final_14.ptau");
const r1cs = fs.readFileSync("circom/circuit.r1cs");
const ptau = fs.readFileSync("circom/powersOfTau28_hez_final_14.ptau");
// snarkjs groth16 setup circuit.r1cs powersOfTau28_hez_final_14.ptau circuit_0000.zkey
const zkey_0 = {type: "file", fileName: "circuit_0000.zkey"};
const zkey_0 = {type: "file", fileName: "circom/circuit_0000.zkey"};
await createOverride(zkey_0);
console.log("groth16 setup...");
await snarkjs.zKey.newZKey(r1cs, ptau, zkey_0);
// snarkjs zkey contribute circuit_0000.zkey circuit_final.zkey -e="<Notary's entropy>"
const zkey_final = {type: "file", fileName: "circuit_final.zkey.notary"};
const zkey_final = {type: "file", fileName: "circom/circuit_final.zkey.notary"};
await createOverride(zkey_final);
console.log("zkey contribute...");
await snarkjs.zKey.contribute(zkey_0, zkey_final, "", entropy);
@@ -36,7 +36,7 @@ async function main(){
console.log("zkey export...");
const vKey = await snarkjs.zKey.exportVerificationKey(zkey_final);
// copied from snarkjs/cli.js zkeyExportVKey()
await bfj.write("verification_key.json", stringifyBigInts(vKey), { space: 1 });
await bfj.write("circom/verification_key.json", stringifyBigInts(vKey), { space: 1 });
}
main().then(() => {

circom/package.json Normal file

@@ -0,0 +1,8 @@
{
"dependencies": {
"circom": "^0.5.46",
"circom2": "^0.2.5",
"circomlibjs": "^0.1.7",
"snarkjs": "^0.4.24"
}
}

Binary file not shown.


@@ -13,7 +13,7 @@ async function main(){
const proof_path = process.argv[4];
const input = fs.readFileSync(input_path);
const wasm = fs.readFileSync(path.join("circuit_js", "circuit.wasm"));
const wasm = fs.readFileSync(path.join("circom", "circuit_js", "circuit.wasm"));
const zkey_final = fs.readFileSync(proving_key_path);
const in_json = JSON.parse(input);


@@ -10,7 +10,7 @@ async function main(retval){
const pub_path = process.argv[2];
const proof_path = process.argv[3];
const vk = JSON.parse(fs.readFileSync("verification_key.json", "utf8"));
const vk = JSON.parse(fs.readFileSync("circom/verification_key.json", "utf8"));
const pub = JSON.parse(fs.readFileSync(pub_path, "utf8"));
const proof = JSON.parse(fs.readFileSync(proof_path, "utf8"));

script.py

@@ -1,168 +0,0 @@
import sys
import random
import os
def padded_hex(s):
h = hex(s)
l = len(h)
if l % 2 == 0:
return h
else:
return '0x{0:0{1}x}'.format(s,l-1)
# This script will generate the circom circuit and inputs
if __name__ == '__main__':
if len(sys.argv) != 2:
print('Expected 1 argument: amount of plaintext to process (in Field elements). Max 16')
exit(1)
count = int(sys.argv[1])
if count > 16:
print('Max 16 Field elements allowed. Got ', count)
exit(1)
input = '{\n'
input += '"plaintext_hash": "'+str(random.randint(0, 2**254))+'",\n'
input += '"label_sum_hash": "'+str(random.randint(0, 2**254))+'",\n'
input += '"sum_of_zero_labels": "'+str(random.randint(0, 2**140))+'",\n'
input += '"plaintext": [\n'
for c in range(0, count):
input += ' "'+str(random.randint(0, 2**253))+'"'
if c < count-1:
input += ',\n'
input += "],\n"
input += '"delta": [\n'
for c in range(0, count):
input += ' [\n'
for x in range(0, 254):
input += ' "'+str(random.randint(0, 2**253))+'"'
if x < 253:
input += ',\n'
input += ' ]\n'
if c < count-1:
input += ',\n'
input += ']\n'
input += '}\n'
with open('input.json', 'w') as f:
f.write(input)
main = 'pragma circom 2.0.0;\n'
main += 'include "./poseidon.circom";\n'
main += 'include "./utils.circom";\n'
main += '\n'
main += 'template Main() {\n'
main += ' signal input plaintext_hash;\n'
main += ' signal input label_sum_hash;\n'
main += ' signal input plaintext['+str(count)+'];\n'
main += ' signal input delta['+str(count)+'][253];\n'
main += ' signal input sum_of_zero_labels;\n'
main += ' signal sums['+str(count)+'];\n'
main += '\n'
main += ' component hash = Poseidon('+str(count)+');\n'
main += ' for (var i = 0; i<'+str(count)+'; i++) {\n'
main += ' hash.inputs[i] <== plaintext[i];\n'
main += ' }\n'
main += '\n'
main += " // TODO to pass this assert we'd have to\n"
main += ' // use actual values instead of random ones, so commenting out for now\n'
main += ' plaintext_hash === hash.out;\n'
main += '\n'
# check that Prover's hash is correct. hashing 16 field elements at a time since
# idk how to chain hashes with circomlib.
# Using prev. digest as the first input to the next hash
# if is_final is true then count includes the sum_of_labels
# def hash(no, start, count, is_final=False):
# out = ' component hash_'+str(no)+' = Poseidon('+str(count)+');\n'
# if no > 0:
# #first element is prev. hash digest
# out += ' hash_'+str(no)+'.inputs[0] <== hash_'+str(no-1)+'.out;\n'
# else:
# if is_final and count == 1:
# out += ' hash_'+str(no)+'.inputs[0] <== sum_of_labels;\n'
# else:
# out += ' hash_'+str(no)+'.inputs[0] <== plaintext['+str(start)+'];\n'
# for x in range(1, count-1):
# out += ' hash_'+str(no)+'.inputs['+str(x)+'] <== plaintext['+str(start+x)+'];\n'
# if is_final:
# # sum of labels if the last input
# out += ' hash_'+str(no)+'.inputs['+str(count-1)+'] <== sum_of_labels;\n'
# else:
# out += ' hash_'+str(no)+'.inputs['+str(count-1)+'] <== plaintext['+str(start+count-1)+'];\n'
# out += '\n'
# return out
# def hash_str():
# out = ''
# if count+1 <= 16:
# out += hash(0, 0, count+1, True)
# out += ' prover_hash <== hash_0.out;\n'
# return out
# else:
# out += hash(0, 0, 16, False)
# if count+1 <= 32:
# out += hash(1, 16, count+1-16, True)
# out += ' prover_hash <== hash_1.out;\n'
# return out
# else:
# out += hash(1, 16, 16, False)
# if count+1 <= 48:
# out += hash(2, 16, count+1-32, True)
# out += ' prover_hash <== hash_2.out;\n'
# return out
# else:
# out += hash(2, 16, 16, False)
# if count+1 <= 64:
# out += hash(3, 16, count+1-48, True)
# out += ' prover_hash <== hash_3.out;\n'
# return out
# else:
# out += hash(3, 16, 16, False)
# if count+1 <= 80:
# out += hash(4, 16, count+1-64, True)
# out += ' prover_hash <== hash_3.out;\n'
# return out
# else:
# out += hash(4, 16, 16, False)
# main += '\n'
# main += hash_str()
# main += '\n'
main += ' component ip['+str(count)+'];\n'
main += ' for (var i = 0; i<'+str(count)+'; i++) {\n'
main += ' ip[i] = InnerProd();\n'
main += ' ip[i].plaintext <== plaintext[i];\n'
main += ' for (var j=0; j<253; j++) {\n'
main += ' ip[i].deltas[j] <== delta[i][j];\n'
main += ' }\n'
main += ' sums[i] <== ip[i].out;\n'
main += ' }\n'
main += '\n'
main += ' signal sum_of_deltas <== '
for c in range(0, count):
main += 'sums['+str(c)+']'
if c < count-1:
main += ' + '
else:
main += ';\n'
main += " // TODO to pass this assert we'd have to\n"
main += ' // use actual values instead of random ones, so commenting out for now\n'
main += ' component ls_hash = Poseidon(1);\n'
main += ' ls_hash.inputs[0] <== sum_of_zero_labels + sum_of_deltas;\n'
main += ' label_sum_hash === ls_hash.out;\n'
main += '}\n'
main += 'component main {public [plaintext_hash, label_sum_hash, delta, sum_of_zero_labels]} = Main();'
with open('test.circom', 'w') as f:
f.write(main)

File diff suppressed because it is too large.


@@ -1,111 +1,53 @@
pub mod circuit;
pub mod onetimesetup;
pub mod poseidon_spec;
pub mod poseidon;
pub mod prover;
pub mod utils;
pub mod verifier;
use crate::halo2_backend::onetimesetup::OneTimeSetup;
use crate::halo2_backend::utils::u8vec_to_boolvec;
use rand::{thread_rng, Rng};
#[test]
fn e2e_test() {
let mut rng = thread_rng();
/// The number of useful bits, see [crate::prover::Prove::useful_bits].
/// This value is hard-coded into the circuit regardless of whether we use pasta
/// curves (field size 255) or the bn254 curve (field size 254).
const USEFUL_BITS: usize = 253;
let ots = OneTimeSetup::new();
/// The size of the chunk, see [crate::prover::Prove::chunk_size].
/// We use 14 field elements of 253 bits and 128 bits of the 15th field
/// element: 14 * 253 + 128 == 3670 bits total. The low 125 bits
/// of the last field element will be used for the salt.
const CHUNK_SIZE: usize = 3670;
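The chunk arithmetic above can be pinned down with compile-time checks; a minimal sketch, assuming only the two constants defined here (64-bit target):

// sanity checks for the chunk layout (sketch)
const _: () = assert!(14 * 253 + 128 == CHUNK_SIZE);
// last field element: 128 plaintext bits + 125 salt bits == 253 useful bits
const _: () = assert!(128 + 125 == USEFUL_BITS);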
// The Prover should have generated the proving key (before the labelsum
// protocol starts) like this:
ots.setup().unwrap();
let proving_key = ots.get_proving_key();
/// The elliptic curve on which the Poseidon hash will be computed.
pub enum Curve {
PASTA,
BN254,
}
// generate random plaintext of random size up to 300 bytes
let plaintext: Vec<u8> = core::iter::repeat_with(|| rng.gen::<u8>())
.take(thread_rng().gen_range(0..300))
.collect();
#[cfg(test)]
mod tests {
use super::onetimesetup::OneTimeSetup;
use super::prover::Prover;
use super::verifier::Verifier;
use super::*;
use crate::tests::fixtures::e2e_test;
// Normally, the Prover is expected to obtain her binary labels by
// evaluating the garbled circuit.
// To keep this test simple, we don't evaluate the gc, but we generate
// all labels of the Verifier and give the Prover her active labels.
let bit_size = plaintext.len() * 8;
let mut all_binary_labels: Vec<[u128; 2]> = Vec::with_capacity(bit_size);
let mut delta: u128 = rng.gen();
// set the last bit
delta |= 1;
for _ in 0..bit_size {
let label_zero: u128 = rng.gen();
all_binary_labels.push([label_zero, label_zero ^ delta]);
#[test]
/// Tests the whole authdecode protocol end-to-end
fn halo2_e2e_test() {
let mut prover_ots = OneTimeSetup::new();
let mut verifier_ots = OneTimeSetup::new();
// The Prover should have generated the proving key (before the labelsum
// protocol starts) like this:
prover_ots.setup();
let proving_key = prover_ots.get_proving_key();
// The Verifier should have generated the verifying key (before the labelsum
// protocol starts) like this:
verifier_ots.setup();
let verification_key = verifier_ots.get_verification_key();
let prover = Box::new(Prover::new(proving_key));
let verifier = Box::new(Verifier::new(verification_key, Curve::PASTA));
e2e_test(prover, verifier);
}
let prover_labels = choose(&all_binary_labels, &u8vec_to_boolvec(&plaintext));
let verifier = LabelsumVerifier::new(
all_binary_labels.clone(),
Box::new(verifiernode::VerifierNode {}),
);
let verifier = verifier.setup().unwrap();
let prover = LabelsumProver::new(
proving_key,
prime,
plaintext,
poseidon,
Box::new(provernode::ProverNode {}),
);
// Perform setup
let prover = prover.setup().unwrap();
// Commitment to the plaintext is sent to the Notary
let (plaintext_hash, prover) = prover.plaintext_commitment().unwrap();
// Notary sends back encrypted arithm. labels.
let (ciphertexts, verifier) = verifier.receive_plaintext_hashes(plaintext_hash);
// Hash commitment to the label_sum is sent to the Notary
let (label_sum_hashes, prover) = prover
.labelsum_commitment(ciphertexts, &prover_labels)
.unwrap();
// Notary sends the arithmetic label seed
let (seed, verifier) = verifier.receive_labelsum_hashes(label_sum_hashes);
// At this point the following happens in the `committed GC` protocol:
// - the Notary reveals the GC seed
// - the User checks that the GC was created from that seed
// - the User checks that her active output labels correspond to the
// output labels derived from the seed
// - we are called with the result of the check and (if successful)
// with all the output labels
let prover = prover
.binary_labels_authenticated(true, Some(all_binary_labels))
.unwrap();
// Prover checks the integrity of the arithmetic labels and generates zero_sums and deltas
let prover = prover.authenticate_arithmetic_labels(seed).unwrap();
// Prover generates the proof
let (proofs, prover) = prover.create_zk_proof().unwrap();
// Notary verifies the proof
let verifier = verifier.verify_many(proofs).unwrap();
assert_eq!(
type_of(&verifier),
"labelsum::verifier::LabelsumVerifier<labelsum::verifier::VerificationSuccessfull>"
);
}
/// Unzips a slice of pairs, returning items corresponding to choice
fn choose<T: Clone>(items: &[[T; 2]], choice: &[bool]) -> Vec<T> {
assert!(items.len() == choice.len(), "arrays are different length");
items
.iter()
.zip(choice)
.map(|(items, choice)| items[*choice as usize].clone())
.collect()
}
fn type_of<T>(_: &T) -> &'static str {
std::any::type_name::<T>()
}
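A toy usage of the `choose` helper above, with hypothetical values:

#[test]
fn choose_toy() {
    let pairs = [[0u8, 1], [10, 11]];
    // pick the 1-label from the first pair and the 0-label from the second
    assert_eq!(choose(&pairs, &[true, false]), vec![1, 10]);
}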


@@ -1,26 +1,21 @@
use super::circuit::{LabelsumCircuit, K, USEFUL_ROWS};
use super::circuit::{LabelsumCircuit, CELLS_PER_ROW, K, USEFUL_ROWS};
use super::prover::PK;
use super::verifier::VK;
use halo2_proofs::plonk;
use halo2_proofs::plonk::ProvingKey;
use halo2_proofs::plonk::VerifyingKey;
use halo2_proofs::poly::commitment::Params;
use pasta_curves::pallas::Base as F;
use pasta_curves::EqAffine;
pub struct OneTimeSetup {
proving_key: Option<ProvingKey<EqAffine>>,
verification_key: Option<VerifyingKey<EqAffine>>,
proving_key: Option<PK>,
verification_key: Option<VK>,
}
#[derive(Debug)]
pub enum Error {
FileDoesNotExist,
SnarkjsError,
}
// OneTimeSetup should be run when Notary starts. It generates a proving and
// a verification keys.
// Note that currently halo2 does not support serializing the proving/verification
// keys. That's why we can't use cached keys but need to re-generate them every time.
/// OneTimeSetup generates the proving key and the verification key. It can be run
/// ahead of time, before the actual zk proving/verification takes place.
///
/// Note that as of Oct 2022 halo2 does not support serializing the proving/verification
/// keys. That's why we can't use cached keys but need to call this one-time setup every
/// time we instantiate the halo2 prover/verifier.
impl OneTimeSetup {
pub fn new() -> Self {
Self {
@@ -29,27 +24,32 @@ impl OneTimeSetup {
}
}
pub fn setup(&self) -> Result<(), Error> {
pub fn setup(&mut self) {
let params: Params<EqAffine> = Params::new(K);
// we need an instance of the circuit, the exact inputs don't matter
let dummy1 = [F::from(0); 15];
let dummy2: [Vec<F>; USEFUL_ROWS] = (0..USEFUL_ROWS)
.map(|_| vec![F::from(0)])
.collect::<Vec<_>>()
.try_into()
.unwrap();
let circuit = LabelsumCircuit::new(dummy1, dummy2);
let circuit = LabelsumCircuit::new(
Default::default(),
Default::default(),
[[Default::default(); CELLS_PER_ROW]; USEFUL_ROWS],
);
// safe to unwrap, we are inputting the same params and circuit on every
// invocation
let vk = plonk::keygen_vk(&params, &circuit).unwrap();
let pk = plonk::keygen_pk(&params, vk.clone(), &circuit).unwrap();
self.proving_key = Some(pk);
self.verification_key = Some(vk);
Ok(())
self.proving_key = Some(PK {
key: pk,
params: params.clone(),
});
self.verification_key = Some(VK { key: vk, params });
}
pub fn get_proving_key(&self) -> ProvingKey<EqAffine> {
self.proving_key.unwrap()
pub fn get_proving_key(&self) -> PK {
self.proving_key.as_ref().unwrap().clone()
}
pub fn get_verification_key(&self) -> VK {
self.verification_key.as_ref().unwrap().clone()
}
}
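A minimal usage sketch, mirroring how the test in mod.rs drives this setup:

#[test]
fn onetimesetup_keys() {
    // generate both keys once, ahead of the actual proving/verification
    let mut ots = OneTimeSetup::new();
    ots.setup();
    let _pk = ots.get_proving_key();
    let _vk = ots.get_verification_key();
}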


@@ -0,0 +1,130 @@
use group::ff::Field;
use halo2_gadgets::poseidon::primitives::Spec;
use halo2_gadgets::poseidon::primitives::{self as poseidon, ConstantLength};
use halo2_gadgets::poseidon::Pow5Chip;
use halo2_gadgets::poseidon::Pow5Config;
use halo2_proofs::plonk::ConstraintSystem;
use pasta_curves::pallas::Base as F;
use pasta_curves::Fp;
/// Spec for rate 15 Poseidon which halo2 uses both inside
/// the zk circuit and in the clear.
///
/// Compare it to the spec which zcash uses:
/// [halo2_gadgets::poseidon::primitives::P128Pow5T3]
#[derive(Debug)]
pub struct Spec15;
impl Spec<Fp, 16, 15> for Spec15 {
fn full_rounds() -> usize {
8
}
fn partial_rounds() -> usize {
56
}
fn sbox(val: Fp) -> Fp {
val.pow_vartime(&[5])
}
/// TODO: waiting on a definitive answer on whether returning 0 here is safe
/// https://github.com/zcash/halo2/issues/674
fn secure_mds() -> usize {
0
}
}
/// Spec for rate 1 Poseidon which halo2 uses both inside
/// the zk circuit and in the clear.
///
/// Compare it to the spec which zcash uses:
/// [halo2_gadgets::poseidon::primitives::P128Pow5T3]
#[derive(Debug)]
pub struct Spec1;
impl Spec<Fp, 2, 1> for Spec1 {
fn full_rounds() -> usize {
8
}
fn partial_rounds() -> usize {
56
}
fn sbox(val: Fp) -> Fp {
val.pow_vartime(&[5])
}
fn secure_mds() -> usize {
0
}
}
/// Hashes inputs with rate 15 Poseidon and returns the digest
///
/// Patterned after [halo2_gadgets::poseidon::pow5]
/// (see in that file tests::poseidon_hash())
pub fn poseidon_15(field_elements: &[F; 15]) -> F {
poseidon::Hash::<F, Spec15, ConstantLength<15>, 16, 15>::init().hash(*field_elements)
}
/// Hashes inputs with rate 1 Poseidon and returns the digest
///
/// Patterned after [halo2_gadgets::poseidon::pow5]
/// (see in that file tests::poseidon_hash())
pub fn poseidon_1(field_elements: &[F; 1]) -> F {
poseidon::Hash::<F, Spec1, ConstantLength<1>, 2, 1>::init().hash(*field_elements)
}
/// Configures the in-circuit Poseidon for rate 15 and returns the config
///
/// Patterned after [halo2_gadgets::poseidon::pow5]
/// (see in that file tests::impl Circuit for HashCircuit::configure())
pub fn configure_poseidon_rate_15<S: Spec<F, 16, 15>>(
rate: usize,
meta: &mut ConstraintSystem<F>,
) -> Pow5Config<Fp, 16, 15> {
let width = rate + 1;
let state = (0..width).map(|_| meta.advice_column()).collect::<Vec<_>>();
let partial_sbox = meta.advice_column();
let rc_a = (0..width).map(|_| meta.fixed_column()).collect::<Vec<_>>();
let rc_b = (0..width).map(|_| meta.fixed_column()).collect::<Vec<_>>();
meta.enable_constant(rc_b[0]);
Pow5Chip::configure::<S>(
meta,
state.try_into().unwrap(),
partial_sbox,
rc_a.try_into().unwrap(),
rc_b.try_into().unwrap(),
)
}
/// Configures the in-circuit Poseidon for rate 1 and returns the config
///
/// Patterned after [halo2_gadgets::poseidon::pow5]
/// (see in that file tests::impl Circuit for HashCircuit::configure())
pub fn configure_poseidon_rate_1<S: Spec<F, 2, 1>>(
rate: usize,
meta: &mut ConstraintSystem<F>,
) -> Pow5Config<Fp, 2, 1> {
let width = rate + 1;
let state = (0..width).map(|_| meta.advice_column()).collect::<Vec<_>>();
let partial_sbox = meta.advice_column();
let rc_a = (0..width).map(|_| meta.fixed_column()).collect::<Vec<_>>();
let rc_b = (0..width).map(|_| meta.fixed_column()).collect::<Vec<_>>();
meta.enable_constant(rc_b[0]);
Pow5Chip::configure::<S>(
meta,
state.try_into().unwrap(),
partial_sbox,
rc_a.try_into().unwrap(),
rc_b.try_into().unwrap(),
)
}
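A minimal sketch of using the two hash helpers above; assumes ff's Field trait for the zero constant:

#[test]
fn poseidon_helpers_toy() {
    use group::ff::Field;
    let digest = poseidon_15(&[F::zero(); 15]); // one rate-15 permutation
    let _salted = poseidon_1(&[digest]); // rate-1 hash of a single element
}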


@@ -1,51 +0,0 @@
use group::ff::Field;
/// Specs which halo2 uses to compute a Poseidon hash both inside the zk
/// circuit and in the clear.
///
///
use halo2_gadgets::poseidon::primitives::Spec;
use pasta_curves::Fp;
// Poseidon spec for 15-rate Poseidon
#[derive(Debug, Clone, Copy)]
pub struct Spec15;
impl Spec<Fp, 16, 15> for Spec15 {
fn full_rounds() -> usize {
8
}
fn partial_rounds() -> usize {
56
}
fn sbox(val: Fp) -> Fp {
val.pow_vartime(&[5])
}
fn secure_mds() -> usize {
0
}
}
// Poseidon spec for 1-rate Poseidon
#[derive(Debug, Clone, Copy)]
pub struct Spec1;
impl Spec<Fp, 2, 1> for Spec1 {
fn full_rounds() -> usize {
8
}
fn partial_rounds() -> usize {
56
}
fn sbox(val: Fp) -> Fp {
val.pow_vartime(&[5])
}
fn secure_mds() -> usize {
0
}
}


@@ -1,54 +1,47 @@
use super::circuit::{LabelsumCircuit, CELLS_PER_ROW, FULL_FIELD_ELEMENTS, K, USEFUL_ROWS};
use super::poseidon_spec::{Spec1, Spec15};
use super::utils::boolvec_to_u8vec;
use super::utils::{bigint_to_f, f_to_bigint};
use crate::prover::{ProofInput, Prove, ProverError, ProvingKeyTrait};
use halo2_gadgets::poseidon::{
primitives::{self as poseidon, ConstantLength, Spec},
Hash, Pow5Chip, Pow5Config,
use super::circuit::{
LabelsumCircuit, CELLS_PER_ROW, SALT_SIZE, TOTAL_FIELD_ELEMENTS, USEFUL_ROWS,
};
use halo2_proofs::arithmetic::FieldExt;
use super::poseidon::{poseidon_1, poseidon_15};
use super::utils::{bigint_to_f, deltas_to_matrices, f_to_bigint};
use super::{CHUNK_SIZE, USEFUL_BITS};
use crate::prover::{ProofInput, Prove, ProverError};
use halo2_proofs::plonk;
use halo2_proofs::plonk::ProvingKey;
use halo2_proofs::plonk::SingleVerifier;
use halo2_proofs::poly::commitment::Params;
use halo2_proofs::transcript::Blake2bRead;
use halo2_proofs::transcript::Blake2bWrite;
use halo2_proofs::transcript::Challenge255;
use halo2_proofs::transcript::{Blake2bWrite, Challenge255};
use instant::Instant;
use num::BigUint;
use pasta_curves::pallas::Base as F;
use pasta_curves::EqAffine;
use rand::{thread_rng, Rng};
use rand::thread_rng;
// halo2's native ProvingKey can't be used without params, so we wrap
// them in one struct.
/// halo2's native ProvingKey can't be used without params, so we wrap
/// them in one struct.
#[derive(Clone)]
pub struct PK {
key: ProvingKey<EqAffine>,
params: Params<EqAffine>,
pub key: ProvingKey<EqAffine>,
pub params: Params<EqAffine>,
}
impl ProvingKeyTrait for PK {}
pub struct Prover {}
pub struct Prover {
proving_key: PK,
}
impl Prove for Prover {
fn prove(&self, input: ProofInput, proving_key: PK) -> Result<Vec<u8>, ProverError> {
// convert each delta into a field element type
let deltas: Vec<F> = input.deltas.iter().map(|d| bigint_to_f(d)).collect();
fn prove(&self, input: ProofInput) -> Result<Vec<u8>, ProverError> {
if input.deltas.len() != self.chunk_size() || input.plaintext.len() != TOTAL_FIELD_ELEMENTS
{
// this can only be caused by an error in
// `crate::prover::LabelsumProver` logic
return Err(ProverError::InternalError);
}
// to make handling simpler, we pad each set of 253 deltas
// with 3 zero deltas on the left.
let deltas: Vec<F> = deltas
.chunks(self.useful_bits())
.map(|c| {
let mut v = vec![F::from(0); 3];
v.extend(c.to_vec());
v
})
.flatten()
.collect();
// convert into matrices
let (deltas_as_rows, deltas_as_columns) =
deltas_to_matrices(&input.deltas, self.useful_bits());
// convert plaintext into field element type
let plaintext: [F; 15] = input
// convert plaintext into F type
let plaintext: [F; TOTAL_FIELD_ELEMENTS] = input
.plaintext
.iter()
.map(|bigint| bigint_to_f(bigint))
@@ -56,100 +49,53 @@ impl Prove for Prover {
.try_into()
.unwrap();
// number of chunks should be equal to USEFUL_ROWS
let all_deltas: [Vec<F>; USEFUL_ROWS] = deltas
.chunks(CELLS_PER_ROW)
.map(|c| c.to_vec())
.collect::<Vec<_>>()
.try_into()
.unwrap();
// transpose to make CELLS_PER_ROW instance columns
let input_deltas: [Vec<F>; CELLS_PER_ROW] = (0..CELLS_PER_ROW)
.map(|i| {
all_deltas
.iter()
.map(|inner| inner[i].clone())
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
.try_into()
.unwrap();
let circuit = LabelsumCircuit::new(plaintext, all_deltas.clone());
use instant::Instant;
let now = Instant::now();
let params = proving_key;
let params: Params<EqAffine> = Params::new(K);
let vk = plonk::keygen_vk(&params, &circuit).unwrap();
let pk = plonk::keygen_pk(&params, vk.clone(), &circuit).unwrap();
println!("ProvingKey built [{:?}]", now.elapsed());
let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]);
let mut all_inputs: Vec<&[F]> = Vec::new();
for i in 0..input_deltas.len() {
let d = input_deltas[i].as_slice();
all_inputs.push(d);
}
// arrange into the format which halo2 expects
let mut all_inputs: Vec<&[F]> = deltas_as_columns.iter().map(|v| v.as_slice()).collect();
// add another column with public inputs
let tmp = &[
bigint_to_f(&input.plaintext_hash),
bigint_to_f(&input.label_sum_hash),
bigint_to_f(&input.sum_of_zero_labels),
];
all_inputs.push(tmp);
println!("{:?} all inputs len", all_inputs.len());
let now = Instant::now();
// prepare the proving system and generate the proof:
let circuit =
LabelsumCircuit::new(plaintext, bigint_to_f(&input.salt), deltas_as_rows.into());
let params = &self.proving_key.params;
let pk = &self.proving_key.key;
let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]);
let mut rng = thread_rng();
plonk::create_proof(
&params,
&pk,
params,
pk,
&[circuit],
&[all_inputs.as_slice()],
&mut rng,
&mut transcript,
)
.unwrap();
//console::log_1(&format!("Proof created {:?}", now.elapsed()).into());
println!("Proof created [{:?}]", now.elapsed());
let proof = transcript.finalize();
let now = Instant::now();
let strategy = SingleVerifier::new(&params);
let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]);
plonk::verify_proof(
&params,
&vk,
strategy,
&[all_inputs.as_slice()],
&mut transcript,
)
.unwrap();
//console::log_1(&format!("Proof verified {:?}", now.elapsed()).into());
//console::log_1(&format!("Proof created {:?}", now.elapsed()).into());
println!("Proof verified [{:?}]", now.elapsed());
println!("Proof size [{} kB]", proof.len() as f64 / 1024.0);
Ok(proof)
}
fn useful_bits(&self) -> usize {
253
USEFUL_BITS
}
fn poseidon_rate(&self) -> usize {
15
TOTAL_FIELD_ELEMENTS
}
fn permutation_count(&self) -> usize {
@@ -157,19 +103,16 @@ impl Prove for Prover {
}
fn salt_size(&self) -> usize {
125
SALT_SIZE
}
// we have 14 field elements of 253 bits and 128 bits of the 15th field
// element, i.e. 14*253+128==3670 bits. The least 125 bits of the last
// field element will be used for the salt.
fn chunk_size(&self) -> usize {
3670
CHUNK_SIZE
}
// Hashes inputs with Poseidon and returns the digest.
/// Hashes `inputs` with Poseidon and returns the digest as `BigUint`.
fn hash(&self, inputs: &Vec<BigUint>) -> Result<BigUint, ProverError> {
let d = match inputs.len() {
let digest = match inputs.len() {
15 => {
// hash with rate-15 Poseidon
let fes: [F; 15] = inputs
@@ -178,7 +121,7 @@ impl Prove for Prover {
.collect::<Vec<_>>()
.try_into()
.unwrap();
poseidon::Hash::<F, Spec15, ConstantLength<15>, 16, 15>::init().hash(fes)
poseidon_15(&fes)
}
1 => {
// hash with rate-1 Poseidon
@@ -188,179 +131,16 @@ impl Prove for Prover {
.collect::<Vec<_>>()
.try_into()
.unwrap();
poseidon::Hash::<F, Spec1, ConstantLength<1>, 2, 1>::init().hash(fes)
poseidon_1(&fes)
}
_ => return Err(ProverError::WrongPoseidonInput),
};
Ok(f_to_bigint(&d))
Ok(f_to_bigint(&digest))
}
}
#[test]
pub fn maintest() {
use super::circuit::{LabelsumCircuit, CELLS_PER_ROW, FULL_FIELD_ELEMENTS, K, USEFUL_ROWS};
use super::utils::boolvec_to_u8vec;
use halo2_proofs::arithmetic::FieldExt;
use halo2_proofs::plonk;
use halo2_proofs::plonk::SingleVerifier;
use halo2_proofs::poly::commitment::Params;
use halo2_proofs::transcript::Blake2bRead;
use halo2_proofs::transcript::Blake2bWrite;
use halo2_proofs::transcript::Challenge255;
use pasta_curves::EqAffine;
use rand::{thread_rng, Rng};
let mut rng = thread_rng();
// generate random plaintext to fill all cells. The first 3 bits of each
// full field element are not used, so we zero them out
const TOTAL_PLAINTEXT_SIZE: usize = CELLS_PER_ROW * USEFUL_ROWS;
let mut plaintext_bits: [bool; TOTAL_PLAINTEXT_SIZE] =
core::iter::repeat_with(|| rng.gen::<bool>())
.take(TOTAL_PLAINTEXT_SIZE)
.collect::<Vec<_>>()
.try_into()
.unwrap();
for i in 0..FULL_FIELD_ELEMENTS {
plaintext_bits[256 * i + 0] = false;
plaintext_bits[256 * i + 1] = false;
plaintext_bits[256 * i + 2] = false;
impl Prover {
pub fn new(pk: PK) -> Self {
Self { proving_key: pk }
}
// random deltas. The first 3 deltas of each set of 256 are not used, so we
// zero them out.
let mut deltas: [F; TOTAL_PLAINTEXT_SIZE] =
core::iter::repeat_with(|| F::from_u128(rng.gen::<u128>()))
.take(TOTAL_PLAINTEXT_SIZE)
.collect::<Vec<_>>()
.try_into()
.unwrap();
for i in 0..FULL_FIELD_ELEMENTS {
deltas[256 * i + 0] = F::from(0);
deltas[256 * i + 1] = F::from(0);
deltas[256 * i + 2] = F::from(0);
}
let pt_chunks = plaintext_bits.chunks(256);
// plaintext has 15 BigUint field elements - this is how the User is
// expected to call this halo2 prover
let plaintext: [BigUint; FULL_FIELD_ELEMENTS + 1] = pt_chunks
.map(|c| {
// convert each chunk of bits into a field element
BigUint::from_bytes_be(&boolvec_to_u8vec(c))
})
.collect::<Vec<_>>()
.try_into()
.unwrap();
// number of chunks should be equal to USEFUL_ROWS
let all_deltas: [Vec<F>; USEFUL_ROWS] = deltas
.chunks(CELLS_PER_ROW)
.map(|c| c.to_vec())
.collect::<Vec<_>>()
.try_into()
.unwrap();
// transpose to make CELLS_PER_ROW instance columns
let input_deltas: [Vec<F>; CELLS_PER_ROW] = (0..CELLS_PER_ROW)
.map(|i| {
all_deltas
.iter()
.map(|inner| inner[i].clone())
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
.try_into()
.unwrap();
let mut hash_input = [F::default(); 15];
for i in 0..hash_input.len() {
hash_input[i] = bigint_to_f(&plaintext[i]);
}
let plaintext_digest =
poseidon::Hash::<F, Spec15, ConstantLength<15>, 16, 15>::init().hash(hash_input);
// compute labelsum digest
let mut labelsum = F::from(0);
for it in plaintext_bits.iter().zip(deltas.clone()) {
let (p, d) = it;
let dot_product = F::from(*p) * d;
labelsum += dot_product;
}
let labelsum_digest =
poseidon::Hash::<F, Spec1, ConstantLength<1>, 2, 1>::init().hash([labelsum]);
let circuit = LabelsumCircuit::new(plaintext, all_deltas.clone());
// let prover = MockProver::<pallas::Base>::run(k, &circuit, input_deltas.clone()).unwrap();
// assert_eq!(prover.verify(), Ok(()));
use instant::Instant;
let now = Instant::now();
let params: Params<EqAffine> = Params::new(K);
let vk = plonk::keygen_vk(&params, &circuit).unwrap();
let pk = plonk::keygen_pk(&params, vk.clone(), &circuit).unwrap();
//console::log_1(&format!("ProvingKey built {:?}", now.elapsed()).into());
println!("ProvingKey built [{:?}]", now.elapsed());
let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]);
let mut all_inputs: Vec<&[F]> = Vec::new();
for i in 0..input_deltas.len() {
let d = input_deltas[i].as_slice();
all_inputs.push(d);
}
let tmp = &[plaintext_digest, labelsum_digest];
all_inputs.push(tmp);
println!("{:?} all inputs len", all_inputs.len());
let now = Instant::now();
plonk::create_proof(
&params,
&pk,
&[circuit],
&[all_inputs.as_slice()],
&mut rng,
&mut transcript,
)
.unwrap();
//console::log_1(&format!("Proof created {:?}", now.elapsed()).into());
println!("Proof created [{:?}]", now.elapsed());
let proof = transcript.finalize();
let now = Instant::now();
let strategy = SingleVerifier::new(&params);
let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]);
plonk::verify_proof(
&params,
&vk,
strategy,
&[all_inputs.as_slice()],
&mut transcript,
)
.unwrap();
//console::log_1(&format!("Proof verified {:?}", now.elapsed()).into());
//console::log_1(&format!("Proof created {:?}", now.elapsed()).into());
println!("Proof verified [{:?}]", now.elapsed());
println!("Proof size [{} kB]", proof.len() as f64 / 1024.0);
}
#[test]
fn test() {
use num::FromPrimitive;
let two = BigUint::from_u8(2).unwrap();
let pow_2_64 = two.pow(64);
let pow_2_128 = two.pow(128);
let pow_2_192 = two.pow(192);
}
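A toy model of the instance-column layout handed to plonk::create_proof / plonk::verify_proof above: one column per delta cell position, plus a final column holding the three scalar public inputs (hypothetical small sizes in place of CELLS_PER_ROW / USEFUL_ROWS):

#[test]
fn public_input_layout_toy() {
    // 8 delta columns of 4 rows each, u64 standing in for F
    let deltas_as_columns: Vec<Vec<u64>> = vec![vec![0; 4]; 8];
    let mut all_inputs: Vec<&[u64]> =
        deltas_as_columns.iter().map(|v| v.as_slice()).collect();
    // plaintext_hash, label_sum_hash, sum_of_zero_labels
    let scalars = [1u64, 2, 3];
    all_inputs.push(&scalars);
    assert_eq!(all_inputs.len(), 8 + 1);
}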


@@ -1,16 +1,22 @@
use super::circuit::{CELLS_PER_ROW, USEFUL_ROWS};
use crate::utils::{boolvec_to_u8vec, u8vec_to_boolvec};
use crate::Delta;
use halo2_proofs::arithmetic::FieldExt;
use num::{BigUint, FromPrimitive};
use pasta_curves::Fp as F;
// Decomposes a `BigUint` into bits and returns the bits in BE bit order,
// left padding them with zeroes to the size of 256.
pub fn bigint_to_bits(bigint: BigUint) -> [bool; 256] {
/// Decomposes a `BigUint` into bits and returns the bits in MSB-first bit order,
/// left padding them with zeroes to the size of 256.
pub fn bigint_to_256bits(bigint: BigUint) -> [bool; 256] {
let bits = u8vec_to_boolvec(&bigint.to_bytes_be());
let mut bits256 = vec![false; 256];
bits256[256 - bits.len()..].copy_from_slice(&bits);
bits256.try_into().unwrap()
}
/// Converts a `BigUint` into a field element type.
/// The assumption is that `bigint` was sanitized earlier and is not larger
/// than [crate::verifier::Verify::field_size]
pub fn bigint_to_f(bigint: &BigUint) -> F {
let le = bigint.to_bytes_le();
let mut wide = [0u8; 64];
@@ -18,13 +24,14 @@ pub fn bigint_to_f(bigint: &BigUint) -> F {
F::from_bytes_wide(&wide)
}
/// Converts `F` into a `BigUint` type
pub fn f_to_bigint(f: &F) -> BigUint {
let tmp: [u8; 32] = f.try_into().unwrap();
BigUint::from_bytes_le(&tmp)
}
// Splits up 256 bits into 4 limbs, shifts each limb left
// and returns the shifted limbs as `BigUint`s.
/// Splits up 256 bits into 4 limbs, shifts each limb left
/// and returns the shifted limbs as `BigUint`s.
pub fn bits_to_limbs(bits: [bool; 256]) -> [BigUint; 4] {
// break up the field element into 4 64-bit limbs
// the limb at index 0 is the high limb
@@ -35,8 +42,10 @@ pub fn bits_to_limbs(bits: [bool; 256]) -> [BigUint; 4] {
.try_into()
.unwrap();
// shift each limb to the left
// shift each limb to the left:
let two = BigUint::from_u8(2).unwrap();
// how many bits to shift each limb by
let shift_by: [BigUint; 4] = [192, 128, 64, 0]
.iter()
.map(|s| two.pow(*s))
@@ -52,31 +61,82 @@ pub fn bits_to_limbs(bits: [bool; 256]) -> [BigUint; 4] {
.unwrap()
}
#[inline]
pub fn u8vec_to_boolvec(v: &[u8]) -> Vec<bool> {
let mut bv = Vec::with_capacity(v.len() * 8);
for byte in v.iter() {
for i in 0..8 {
bv.push(((byte >> (7 - i)) & 1) != 0);
}
}
bv
/// Converts a vec of deltas into a matrix of rows and a matrix of columns
/// (padding the deltas as needed) and returns both matrices.
pub fn deltas_to_matrices(
deltas: &Vec<Delta>,
useful_bits: usize,
) -> (
[[F; CELLS_PER_ROW]; USEFUL_ROWS],
[[F; USEFUL_ROWS]; CELLS_PER_ROW],
) {
let deltas = convert_and_pad_deltas(deltas, useful_bits);
let deltas_as_rows = deltas_to_matrix_of_rows(&deltas, useful_bits);
let deltas_as_columns = transpose_rows(&deltas_as_rows);
(deltas_as_rows, deltas_as_columns)
}
// Convert bits into bytes. The bits will be left-padded with zeroes to the
// multiple of 8.
#[inline]
pub fn boolvec_to_u8vec(bv: &[bool]) -> Vec<u8> {
let rem = bv.len() % 8;
let first_byte_bitsize = if rem == 0 { 8 } else { rem };
let offset = if rem == 0 { 0 } else { 1 };
let mut v = vec![0u8; bv.len() / 8 + offset];
// implicitly left-pad the first byte with zeroes
for (i, b) in bv[0..first_byte_bitsize].iter().enumerate() {
v[i / 8] |= (*b as u8) << (first_byte_bitsize - 1 - i);
}
for (i, b) in bv[first_byte_bitsize..].iter().enumerate() {
v[1 + i / 8] |= (*b as u8) << (7 - (i % 8));
}
v
/// To make handling inside the circuit simpler, we pad each chunk of deltas
/// (except for the last one) with zero values on the left to size 256.
/// Note that the last chunk (corresponding to the 15th field element) will
/// contain only 128 deltas, so we do NOT pad it.
///
/// Returns the padded deltas.
fn convert_and_pad_deltas(deltas: &Vec<Delta>, useful_bits: usize) -> Vec<F> {
// convert deltas into F type
let deltas: Vec<F> = deltas.iter().map(|d| bigint_to_f(d)).collect();
deltas
.chunks(useful_bits)
.enumerate()
.map(|(i, c)| {
if i < 14 {
let mut v = vec![F::from(0); 256 - c.len()];
v.extend(c.to_vec());
v
} else {
c.to_vec()
}
})
.flatten()
.collect()
}
/// Converts a vec of padded deltas into a matrix of rows and returns it.
fn deltas_to_matrix_of_rows(
deltas: &Vec<F>,
useful_bits: usize,
) -> [[F; CELLS_PER_ROW]; USEFUL_ROWS] {
deltas
.chunks(CELLS_PER_ROW)
.map(|c| c.try_into().unwrap())
.collect::<Vec<_>>()
.try_into()
.unwrap()
}
/// Transposes a matrix of rows.
fn transpose_rows(matrix: &[[F; CELLS_PER_ROW]; USEFUL_ROWS]) -> [[F; USEFUL_ROWS]; CELLS_PER_ROW] {
(0..CELLS_PER_ROW)
.map(|i| {
matrix
.iter()
.map(|inner| inner[i].clone().try_into().unwrap())
.collect::<Vec<_>>()
.try_into()
.unwrap()
})
.collect::<Vec<_>>()
.try_into()
.unwrap()
}
/// Converts a vec of deltas into a matrix of columns and returns it.
fn deltas_to_matrix_of_columns(
deltas: &Vec<F>,
useful_bits: usize,
) -> [[F; USEFUL_ROWS]; CELLS_PER_ROW] {
transpose_rows(&deltas_to_matrix_of_rows(deltas, useful_bits))
}
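A toy illustration of the transposition performed by transpose_rows() above, with i32 standing in for F (assumes Rust 1.63+ for core::array::from_fn):

#[test]
fn transpose_toy() {
    fn transpose<const R: usize, const C: usize>(m: [[i32; C]; R]) -> [[i32; R]; C] {
        core::array::from_fn(|c| core::array::from_fn(|r| m[r][c]))
    }
    assert_eq!(transpose([[1, 2, 3], [4, 5, 6]]), [[1, 4], [2, 5], [3, 6]]);
}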


@@ -1,63 +1,89 @@
use crate::verifier::{VerifierError, Verify};
use json::{array, object, stringify, stringify_pretty, JsonValue};
use num::{BigUint, FromPrimitive, ToPrimitive, Zero};
use std::env::temp_dir;
use std::fs;
use std::path::Path;
use std::process::{Command, Output};
use uuid::Uuid;
use super::utils::{bigint_to_f, deltas_to_matrices};
use super::{Curve, CHUNK_SIZE, USEFUL_BITS};
use crate::verifier::{VerificationInput, VerifierError, Verify};
use halo2_proofs::plonk;
use halo2_proofs::plonk::SingleVerifier;
use halo2_proofs::plonk::VerifyingKey;
use halo2_proofs::poly::commitment::Params;
use halo2_proofs::transcript::Blake2bRead;
use halo2_proofs::transcript::Challenge255;
use instant::Instant;
use pasta_curves::pallas::Base as F;
use pasta_curves::EqAffine;
pub struct Verifier {}
/// halo2's native [halo2::VerifyingKey] can't be used without params, so we wrap
/// them in one struct.
#[derive(Clone)]
pub struct VK {
pub key: VerifyingKey<EqAffine>,
pub params: Params<EqAffine>,
}
pub struct Verifier {
verification_key: VK,
curve: Curve,
}
impl Verifier {
pub fn new(vk: VK, curve: Curve) -> Self {
Self {
verification_key: vk,
curve,
}
}
}
impl Verify for Verifier {
fn verify(
&self,
proof: Vec<u8>,
deltas: Vec<String>,
plaintext_hash: BigUint,
labelsum_hash: BigUint,
zero_sum: BigUint,
) -> Result<bool, VerifierError> {
// public.json is a flat array
let mut public_json: Vec<String> = Vec::new();
public_json.push(plaintext_hash.to_string());
public_json.push(labelsum_hash.to_string());
public_json.extend::<Vec<String>>(deltas);
public_json.push(zero_sum.to_string());
let s = stringify(JsonValue::from(public_json.clone()));
fn verify(&self, input: VerificationInput) -> Result<bool, VerifierError> {
let params = &self.verification_key.params;
let vk = &self.verification_key.key;
// write into temp files and delete the files after verification
let mut path1 = temp_dir();
let mut path2 = temp_dir();
path1.push(format!("public.json.{}", Uuid::new_v4()));
path2.push(format!("proof.json.{}", Uuid::new_v4()));
fs::write(path1.clone(), s).expect("Unable to write file");
fs::write(path2.clone(), proof).expect("Unable to write file");
let strategy = SingleVerifier::new(&params);
let proof = input.proof;
let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]);
let output = Command::new("node")
.args([
"verify.mjs",
path1.to_str().unwrap(),
path2.to_str().unwrap(),
])
.output();
fs::remove_file(path1).expect("Unable to remove file");
fs::remove_file(path2).expect("Unable to remove file");
// convert deltas into a matrix which halo2 expects
let (_, deltas_as_columns) = deltas_to_matrices(&input.deltas, self.useful_bits());
check_output(&output)?;
if !output.unwrap().status.success() {
return Ok(false);
let mut all_inputs: Vec<&[F]> = deltas_as_columns.iter().map(|v| v.as_slice()).collect();
// add another column with public inputs
let tmp = &[
bigint_to_f(&input.plaintext_hash),
bigint_to_f(&input.label_sum_hash),
bigint_to_f(&input.sum_of_zero_labels),
];
all_inputs.push(tmp);
let now = Instant::now();
// perform the actual verification
let res = plonk::verify_proof(
params,
vk,
strategy,
&[all_inputs.as_slice()],
&mut transcript,
);
println!("Proof verified [{:?}]", now.elapsed());
if res.is_err() {
return Err(VerifierError::VerificationFailed);
} else {
Ok(true)
}
Ok(true)
}
}
fn check_output(output: &Result<Output, std::io::Error>) -> Result<(), VerifierError> {
if output.is_err() {
return Err(VerifierError::SnarkjsError);
fn field_size(&self) -> usize {
match self.curve {
Curve::PASTA => 255,
Curve::BN254 => 254,
_ => panic!("a new curve was added. Add its field size here."),
}
}
if !output.as_ref().unwrap().status.success() {
return Err(VerifierError::SnarkjsError);
fn useful_bits(&self) -> usize {
USEFUL_BITS
}
fn chunk_size(&self) -> usize {
CHUNK_SIZE
}
Ok(())
}
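A hedged wiring sketch, assuming the OneTimeSetup from this backend and the Verify trait in scope:

#[test]
fn verifier_wiring_toy() {
    use crate::verifier::Verify;
    let mut ots = OneTimeSetup::new();
    ots.setup();
    let verifier = Verifier::new(ots.get_verification_key(), Curve::PASTA);
    // the pasta base field is 255 bits wide
    assert_eq!(verifier.field_size(), 255);
}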


@@ -1,85 +1,86 @@
use super::boolvec_to_u8vec;
use num::{BigUint, FromPrimitive, ToPrimitive, Zero};
use rand::SeedableRng;
use rand::{thread_rng, Rng};
use super::utils::bits_to_bigint;
use num::BigUint;
use rand::RngCore;
use rand::{thread_rng, Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;
// The PRG for generating arithmetic labels
/// The PRG for generating arithmetic labels.
type Prg = ChaCha20Rng;
/// The seed from which to generate the arithmetic labels.
pub type Seed = [u8; 32];
// The arithmetic label
// The arithmetic label.
type Label = BigUint;
/// A pair of labels: the first one encodes the value 0, the second one encodes
/// the value 1.
pub type LabelPair = [Label; 2];
/// typestates are used to prevent generate() from being called multiple times
/// on the same instance of LabelSeed
pub trait State {}
pub struct Generate {
seed: Seed,
}
pub struct Finished {}
impl State for Generate {}
impl State for Finished {}
pub struct LabelGenerator<S = Generate>
where
S: State,
{
state: S,
}
pub struct LabelGenerator {}
impl LabelGenerator {
pub fn new() -> LabelGenerator<Generate> {
LabelGenerator {
state: Generate {
seed: thread_rng().gen::<Seed>(),
},
}
/// Generates a seed and then generates `count` arithmetic label pairs
/// of bitsize `label_size` from that seed. Returns the labels and the seed.
pub fn generate(count: usize, label_size: usize) -> (Vec<LabelPair>, Seed) {
let seed = thread_rng().gen::<Seed>();
let pairs = LabelGenerator::generate_from_seed(count, label_size, seed);
(pairs, seed)
}
pub fn new_from_seed(seed: Seed) -> LabelGenerator<Generate> {
LabelGenerator {
state: Generate { seed },
}
// Generates `count` arithmetic label pairs of bitsize `label_size` from a
// seed and returns the labels.
pub fn generate_from_seed(count: usize, label_size: usize, seed: Seed) -> Vec<LabelPair> {
let prg = Prg::from_seed(seed);
LabelGenerator::generate_from_prg(count, label_size, Box::new(prg))
}
}
impl LabelGenerator<Generate> {
/// Generates `count` arithmetic label pairs of size `label_size`. Returns
/// the generated label pairs and the seed.
pub fn generate(
self,
/// Generates `count` arithmetic label pairs of bitsize `label_size` using a PRG.
/// Returns the generated label pairs.
fn generate_from_prg(
count: usize,
label_size: usize,
) -> (Vec<LabelPair>, Seed, LabelGenerator<Finished>) {
// To keep the handling simple, we want to avoid a negative delta, that's why
// W_0 and delta must be (label_size - 1)-bit values and W_1 will be
// set to W_0 + delta
let mut prg = Prg::from_seed(self.state.seed);
let label_pairs: Vec<LabelPair> = (0..count)
mut prg: Box<dyn RngCore>,
) -> Vec<LabelPair> {
(0..count)
.map(|_| {
let zero_label: Vec<bool> = core::iter::repeat_with(|| prg.gen::<bool>())
.take(label_size - 1)
.collect();
let zero_label = BigUint::from_bytes_be(&boolvec_to_u8vec(&zero_label));
// To keep the handling simple, we want to avoid a negative delta; that's why
// W_0 and delta must be (label_size - 1)-bit values and W_1 will be
// set to W_0 + delta
let zero_label = bits_to_bigint(
&core::iter::repeat_with(|| prg.gen::<bool>())
.take(label_size - 1)
.collect::<Vec<_>>(),
);
let delta: Vec<bool> = core::iter::repeat_with(|| prg.gen::<bool>())
.take(label_size - 1)
.collect();
let delta = BigUint::from_bytes_be(&boolvec_to_u8vec(&delta));
let delta = bits_to_bigint(
&core::iter::repeat_with(|| prg.gen::<bool>())
.take(label_size - 1)
.collect::<Vec<_>>(),
);
let one_label = zero_label.clone() + delta.clone();
[zero_label, one_label]
})
.collect();
(
label_pairs,
self.state.seed,
LabelGenerator { state: Finished {} },
)
.collect()
}
}
#[cfg(test)]
mod tests {
use super::LabelGenerator;
use num::BigUint;
use rand::rngs::mock::StepRng;
#[test]
fn test_label_generator() {
// PRG which always returns bit 1
let prg = StepRng::new(u64::MAX, 0);
// zero_label and delta should be 511 (bit 1 repeated 9 times), one_label
// should be 511+511=1022
let result = LabelGenerator::generate_from_prg(10, 10, Box::new(prg));
let expected = (0..10)
.map(|_| [BigUint::from(511u128), BigUint::from(1022u128)])
.collect::<Vec<_>>();
assert_eq!(expected, result);
}
}
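A hedged usage sketch of the public API above, with hypothetical parameters (8 label pairs of 96 bits each):

#[test]
fn label_generator_seed_roundtrip() {
    let (pairs, seed) = LabelGenerator::generate(8, 96);
    assert_eq!(pairs.len(), 8);
    // the same seed must reproduce the same pairs
    assert_eq!(pairs, LabelGenerator::generate_from_seed(8, 96, seed));
}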


@@ -1,128 +1,72 @@
use aes::{Aes128, NewBlockCipher};
use cipher::{consts::U16, generic_array::GenericArray, BlockCipher, BlockEncrypt};
use num::{BigUint, FromPrimitive, ToPrimitive, Zero};
use poseidon::Poseidon;
use rand::{thread_rng, Rng};
use sha2::{Digest, Sha256};
use num::BigUint;
pub mod halo2_backend;
pub mod label;
pub mod onetimesetup;
pub mod poseidon;
pub mod prover;
//pub mod provernode;
pub mod snarkjs_backend;
pub mod utils;
pub mod verifier;
pub mod verifiernode;
// bn254 prime 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
// in decimal 21888242871839275222246405745257275088548364400416034343698204186575808495617
const BN254_PRIME: &str =
"21888242871839275222246405745257275088548364400416034343698204186575808495617";
/// How many field elements our Poseidon hash consumes for one permutation.
const POSEIDON_RATE: usize = 16;
/// How many permutations our circom circuit supports. One permutation consumes
/// POSEIDON_RATE field elements.
const PERMUTATION_COUNT: usize = 1;
/// The bitsize of an arithmetic label. MUST be > 40 to provide statistical
/// security against the Prover guessing the label. For a 254-bit field, a
/// bitsize > 96 would require 2 field elements for the salted labelsum
/// instead of 1.
const ARITHMETIC_LABEL_SIZE: usize = 96;
/// The maximum size (in bits) of one chunk of plaintext that we support
/// for 254-bit fields. Calculated as 2^{Field_size - 1 - 128 - 96}, where
/// 128 is the size of salt and 96 is the size of the arithmetic label.
const MAX_CHUNK_SIZE: usize = 1 << 29;
/// The maximum supported size (in bits) of one [Chunk] of plaintext.
/// Should not exceed 2^{ [prover::Prove::useful_bits] - [prover::Prove::salt_size]
/// - [ARITHMETIC_LABEL_SIZE] }.
/// 2^20 should suffice for most use cases.
const MAX_CHUNK_SIZE: usize = 1 << 20;
pub fn random_bigint(bitsize: usize) -> BigUint {
assert!(bitsize <= 128);
let r: [u8; 16] = thread_rng().gen();
// take only those bits which we need
BigUint::from_bytes_be(&boolvec_to_u8vec(&u8vec_to_boolvec(&r)[0..bitsize]))
}
/// The maximum supported number of plaintext [Chunk]s (which equals the
/// number of zk proofs). Having to verify too many zk proofs may be a DoS
/// vector against the Notary, who is the verifier of the zk proofs.
const MAX_CHUNK_COUNT: usize = 128;
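A quick numeric sanity check of how these constants interact (a sketch inside this crate; 253 and 128 are the useful bits and salt size used by the backends below):
// one chunk packs 253 * 16 * 1 - 128 == 3920 bits of plaintext
let chunk_size = 253 * POSEIDON_RATE * PERMUTATION_COUNT - 128;
let plaintext_bits = 1000 * 8; // hypothetical 1000-byte plaintext
// the chunk count is rounded up; it equals the number of zk proofs
let chunk_count = (plaintext_bits + chunk_size - 1) / chunk_size;
assert_eq!(chunk_count, 3);
assert!(chunk_size <= MAX_CHUNK_SIZE && chunk_count <= MAX_CHUNK_COUNT);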
#[inline]
pub fn u8vec_to_boolvec(v: &[u8]) -> Vec<bool> {
let mut bv = Vec::with_capacity(v.len() * 8);
for byte in v.iter() {
for i in 0..8 {
bv.push(((byte >> (7 - i)) & 1) != 0);
}
}
bv
}
/// The decoded output labels of the garbled circuit. In other words, this is
/// the plaintext output resulting from the evaluation of a garbled circuit.
type Plaintext = Vec<u8>;
// Converts bits into bytes. The bits will be left-padded with zeroes to the
// nearest multiple of 8.
#[inline]
pub fn boolvec_to_u8vec(bv: &[bool]) -> Vec<u8> {
let rem = bv.len() % 8;
let first_byte_bitsize = if rem == 0 { 8 } else { rem };
let offset = if rem == 0 { 0 } else { 1 };
let mut v = vec![0u8; bv.len() / 8 + offset];
// implicitly left-pad the first byte with zeroes
for (i, b) in bv[0..first_byte_bitsize].iter().enumerate() {
v[i / 8] |= (*b as u8) << (first_byte_bitsize - 1 - i);
}
for (i, b) in bv[first_byte_bitsize..].iter().enumerate() {
v[1 + i / 8] |= (*b as u8) << (7 - (i % 8));
}
v
}
/// A chunk of [Plaintext]. The number of vec elements equals
/// [Prove::poseidon_rate] * [Prove::permutation_count]. Each vec element
/// is an "elliptic curve field element" into which [Prove::useful_bits] bits
/// of [Plaintext] are packed.
/// The chunk does NOT contain the [Salt].
type Chunk = Vec<BigUint>;
pub fn sha256(data: &[u8]) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(data);
hasher.finalize().into()
}
/// Before hashing a [Chunk], it is salted by shifting its last element to the
/// left by [Prove::salt_size] bits and placing the salt into the low bits.
/// This same salt is also used to salt the sum of all the labels corresponding
/// to the [Chunk].
/// Without the salt, a hash of low-entropy plaintext could be brute-forced.
type Salt = BigUint;
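A minimal sketch of that salting step (a hypothetical helper, not part of this commit), using the same `BigUint` arithmetic as the rest of the crate:
use num::BigUint;
/// Shifts the chunk's last field element left by `salt_size` bits and
/// places the salt into the freed-up low bits.
fn salt_last_element(last_fe: &BigUint, salt: &BigUint, salt_size: usize) -> BigUint {
(last_fe << salt_size) + salt
}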
/// Encrypts each arithmetic label using a corresponding binary label as a key
/// and returns ciphertexts in an order based on binary label's pointer bit (LSB).
pub fn encrypt_arithmetic_labels(
alabels: &Vec<[BigUint; 2]>,
blabels: &Vec<[u128; 2]>,
) -> Vec<[Vec<u8>; 2]> {
assert!(alabels.len() == blabels.len());
/// A Poseidon hash digest of a [Salt]ed [Chunk]. This is an EC field element.
type PlaintextHash = BigUint;
blabels
.iter()
.zip(alabels)
.map(|(bin_pair, arithm_pair)| {
let zero_key = Aes128::new_from_slice(&bin_pair[0].to_be_bytes()).unwrap();
let one_key = Aes128::new_from_slice(&bin_pair[1].to_be_bytes()).unwrap();
let mut label0 = [0u8; 16];
let mut label1 = [0u8; 16];
let ap0 = arithm_pair[0].to_bytes_be();
let ap1 = arithm_pair[1].to_bytes_be();
// pad with zeroes on the left
label0[16 - ap0.len()..].copy_from_slice(&ap0);
label1[16 - ap1.len()..].copy_from_slice(&ap1);
let mut label0: GenericArray<u8, U16> = GenericArray::from(label0);
let mut label1: GenericArray<u8, U16> = GenericArray::from(label1);
zero_key.encrypt_block(&mut label0);
one_key.encrypt_block(&mut label1);
// ciphertext 0 and ciphertext 1
let ct0 = label0.to_vec();
let ct1 = label1.to_vec();
// place arithmetic labels based on the point-and-permute bit of binary label 0
if (bin_pair[0] & 1) == 0 {
[ct0, ct1]
} else {
[ct1, ct0]
}
})
.collect()
}
/// A Poseidon hash digest of a [Salt]ed arithmetic sum of arithmetic labels
/// corresponding to the [Chunk]. This is an EC field element.
type LabelsumHash = BigUint;
/// An arithmetic sum of all "zero" arithmetic labels (those are the labels
/// which encode the bit value 0) corresponding to one [Chunk].
type ZeroSum = BigUint;
/// An arithmetic difference between the arithmetic label "one" and the
/// arithmetic label "zero".
type Delta = BigUint;
/// A serialized proof proving that a Poseidon hash is the result of hashing a
/// salted [Chunk].
type Proof = Vec<u8>;
#[cfg(test)]
mod tests {
use crate::onetimesetup::OneTimeSetup;
use super::*;
use num::{BigUint, FromPrimitive};
use prover::LabelsumProver;
use rand::{thread_rng, Rng, RngCore};
use rand::{thread_rng, Rng};
use verifier::LabelsumVerifier;
/// Unzips a slice of pairs, returning items corresponding to choice
@@ -139,98 +83,85 @@ mod tests {
std::any::type_name::<T>()
}
// #[test]
// fn e2e_test() {
// let prime = String::from(BN254_PRIME).parse::<BigUint>().unwrap();
// let mut rng = thread_rng();
pub mod fixtures {
use super::utils::*;
use super::*;
use crate::prover::Prove;
use crate::verifier::Verify;
// // OneTimeSetup is a no-op if the setup has been run before
// let ots = OneTimeSetup::new();
// ots.setup().unwrap();
/// Accepts a concrete Prover and Verifier and runs the whole labelsum
/// protocol end-to-end.
pub fn e2e_test(prover: Box<dyn Prove>, verifier: Box<dyn Verify>) {
let mut rng = thread_rng();
// // Poseidon need to be instantiated once and then passed to all instances
// // of Prover
// let poseidon = Poseidon::new();
// generate random plaintext of random size up to 1000 bytes
let plaintext: Vec<u8> = core::iter::repeat_with(|| rng.gen::<u8>())
.take(thread_rng().gen_range(0..1000))
.collect();
// // The Prover should have received the proving key (before the labelsum
// // protocol starts) like this:
// let proving_key = ots.get_proving_key().unwrap();
// Normally, the Prover is expected to obtain her binary labels by
// evaluating the garbled circuit.
// To keep this test simple, we don't evaluate the gc, but we generate
// all labels of the Verifier and give the Prover her active labels.
let bit_size = plaintext.len() * 8;
let mut all_binary_labels: Vec<[u128; 2]> = Vec::with_capacity(bit_size);
let mut delta: u128 = rng.gen();
// set the last bit
delta |= 1;
for _ in 0..bit_size {
let label_zero: u128 = rng.gen();
all_binary_labels.push([label_zero, label_zero ^ delta]);
}
let prover_labels = choose(&all_binary_labels, &u8vec_to_boolvec(&plaintext));
// // generate random plaintext of random size up to 2000 bytes
// let plaintext: Vec<u8> = core::iter::repeat_with(|| rng.gen::<u8>())
// .take(thread_rng().gen_range(0..2000))
// .collect();
let verifier = LabelsumVerifier::new(all_binary_labels.clone(), verifier);
// // Normally, the Prover is expected to obtain her binary labels by
// // evaluating the garbled circuit.
// // To keep this test simple, we don't evaluate the gc, but we generate
// // all labels of the Verifier and give the Prover her active labels.
// let bit_size = plaintext.len() * 8;
// let mut all_binary_labels: Vec<[u128; 2]> = Vec::with_capacity(bit_size);
// let mut delta: u128 = rng.gen();
// // set the last bit
// delta |= 1;
// for _ in 0..bit_size {
// let label_zero: u128 = rng.gen();
// all_binary_labels.push([label_zero, label_zero ^ delta]);
// }
// let prover_labels = choose(&all_binary_labels, &u8vec_to_boolvec(&plaintext));
let verifier = verifier.setup().unwrap();
// let verifier = LabelsumVerifier::new(
// all_binary_labels.clone(),
// Box::new(verifiernode::VerifierNode {}),
// );
let prover = LabelsumProver::new(plaintext, prover);
// let verifier = verifier.setup().unwrap();
// Perform setup
let prover = prover.setup().unwrap();
// let prover = LabelsumProver::new(
// proving_key,
// prime,
// plaintext,
// poseidon,
// Box::new(provernode::ProverNode {}),
// );
// Commitment to the plaintext is sent to the Notary
let (plaintext_hash, prover) = prover.plaintext_commitment().unwrap();
// // Perform setup
// let prover = prover.setup().unwrap();
// Notary sends back encrypted arithm. labels.
let (ciphertexts, verifier) =
verifier.receive_plaintext_hashes(plaintext_hash).unwrap();
// // Commitment to the plaintext is sent to the Notary
// let (plaintext_hash, prover) = prover.plaintext_commitment().unwrap();
// Hash commitment to the label_sum is sent to the Notary
let (label_sum_hashes, prover) = prover
.labelsum_commitment(ciphertexts, &prover_labels)
.unwrap();
// // Notary sends back encrypted arithm. labels.
// let (cipheretexts, verifier) = verifier.receive_plaintext_hashes(plaintext_hash);
// Notary sends the arithmetic label seed
let (seed, verifier) = verifier.receive_labelsum_hashes(label_sum_hashes).unwrap();
// // Hash commitment to the label_sum is sent to the Notary
// let (label_sum_hashes, prover) = prover
// .labelsum_commitment(cipheretexts, &prover_labels)
// .unwrap();
// At this point the following happens in the `committed GC` protocol:
// - the Notary reveals the GC seed
// - the User checks that the GC was created from that seed
// - the User checks that her active output labels correspond to the
// output labels derived from the seed
// - we are called with the result of the check and (if successful)
// with all the output labels
// // Notary sends the arithmetic label seed
// let (seed, verifier) = verifier.receive_labelsum_hashes(label_sum_hashes);
let prover = prover
.binary_labels_authenticated(true, Some(all_binary_labels))
.unwrap();
// // At this point the following happens in the `committed GC` protocol:
// // - the Notary reveals the GC seed
// // - the User checks that the GC was created from that seed
// // - the User checks that her active output labels correspond to the
// // output labels derived from the seed
// // - we are called with the result of the check and (if successful)
// // with all the output labels
// Prover checks the integrity of the arithmetic labels and generates zero_sums and deltas
let prover = prover.authenticate_arithmetic_labels(seed).unwrap();
// let prover = prover
// .binary_labels_authenticated(true, Some(all_binary_labels))
// .unwrap();
// Prover generates the proof
let (proofs, prover) = prover.create_zk_proofs().unwrap();
// // Prover checks the integrity of the arithmetic labels and generates zero_sums and deltas
// let prover = prover.authenticate_arithmetic_labels(seed).unwrap();
// // Prover generates the proof
// let (proofs, prover) = prover.create_zk_proof().unwrap();
// // Notary verifies the proof
// let verifier = verifier.verify_many(proofs).unwrap();
// assert_eq!(
// type_of(&verifier),
// "labelsum::verifier::LabelsumVerifier<labelsum::verifier::VerificationSuccessfull>"
// );
// }
// Notary verifies the proof
let verifier = verifier.verify_many(proofs).unwrap();
assert_eq!(
type_of(&verifier),
"labelsum::verifier::LabelsumVerifier<labelsum::verifier::VerificationSuccessfull>"
);
}
}
}

File diff suppressed because it is too large

View File

@@ -1 +1,32 @@
pub mod onetimesetup;
pub mod poseidon;
pub mod provernode;
pub mod verifiernode;
#[cfg(test)]
mod tests {
use super::onetimesetup::OneTimeSetup;
use super::provernode::Prover;
use super::verifiernode::Verifier;
use crate::tests::fixtures::e2e_test;
#[test]
fn snarkjs_e2e_test() {
let prover_ots = OneTimeSetup::new();
let verifier_ots = OneTimeSetup::new();
// The Prover should have generated the proving key (before the labelsum
// protocol starts) like this:
prover_ots.setup().unwrap();
let proving_key = prover_ots.get_proving_key().unwrap();
// The Verifier should have generated the verifying key (before the labelsum
// protocol starts) like this:
verifier_ots.setup().unwrap();
let verification_key = verifier_ots.get_verification_key().unwrap();
let prover = Box::new(Prover::new(proving_key));
let verifier = Box::new(Verifier::new(verification_key));
e2e_test(prover, verifier);
}
}

View File

@@ -41,16 +41,16 @@ impl OneTimeSetup {
pub fn setup(&self) -> Result<(), Error> {
// check if files which we ship are present
if !Path::new("powersOfTau28_hez_final_14.ptau").exists()
|| !Path::new("circuit.r1cs").exists()
if !Path::new("circom/powersOfTau28_hez_final_14.ptau").exists()
|| !Path::new("circom/circuit.r1cs").exists()
{
return Err(Error::FileDoesNotExist);
}
// check if any of the files hasn't been generated. If so, regenerate
// all files.
if !Path::new("circuit_0000.zkey").exists()
|| !Path::new("circuit_final.zkey.notary").exists()
|| !Path::new("verification_key.json").exists()
if !Path::new("circom/circuit_0000.zkey").exists()
|| !Path::new("circom/circuit_final.zkey.notary").exists()
|| !Path::new("circom/verification_key.json").exists()
{
let entropy = self.generate_entropy();
//return self.regenerate1(entropy);
@@ -62,12 +62,22 @@ impl OneTimeSetup {
// Returns the already existing proving key
pub fn get_proving_key(&self) -> Result<Vec<u8>, Error> {
let path = Path::new("circuit_final.zkey.notary");
let path = Path::new("circom/circuit_final.zkey.notary");
if !path.exists() {
return Err(Error::FileDoesNotExist);
}
let proof = fs::read(path.clone()).unwrap();
Ok(proof)
let key = fs::read(path.clone()).unwrap();
Ok(key)
}
// Returns the already existing verification key
pub fn get_verification_key(&self) -> Result<Vec<u8>, Error> {
let path = Path::new("circom/verification_key.json");
if !path.exists() {
return Err(Error::FileDoesNotExist);
}
let key = fs::read(path.clone()).unwrap();
Ok(key)
}
// this will work only if snarkjs is in the PATH
@@ -76,9 +86,9 @@ impl OneTimeSetup {
.args([
"groth16",
"setup",
"circuit.r1cs",
"powersOfTau28_hez_final_14.ptau",
"circuit_0000.zkey",
"circom/circuit.r1cs",
"circom/powersOfTau28_hez_final_14.ptau",
"circom/circuit_0000.zkey",
])
.output();
self.check_output(output)?;
@@ -87,8 +97,8 @@ impl OneTimeSetup {
.args([
"zkey",
"contribute",
"circuit_0000.zkey",
"circuit_final.zkey.notary",
"circom/circuit_0000.zkey",
"circom/circuit_final.zkey.notary",
&(String::from("-e=\"") + &entropy + &String::from("\"")),
])
.output();
@@ -99,8 +109,8 @@ impl OneTimeSetup {
"zkey",
"export",
"verificationkey",
"circuit_final.zkey.notary",
"verification_key.json",
"circom/circuit_final.zkey.notary",
"circom/verification_key.json",
])
.output();
self.check_output(output)?;
@@ -111,19 +121,9 @@ impl OneTimeSetup {
// call a js wrapper which does what regenerate1() above does
fn regenerate2(&self, entropy: String) -> Result<(), Error> {
let output = Command::new("node")
.args(["onetimesetup.mjs", &entropy])
.args(["circom/onetimesetup.mjs", &entropy])
.output();
self.check_output(output)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test() {
let mut ots = OneTimeSetup::new();
ots.setup().unwrap();
}
}

View File

@@ -1,20 +1,15 @@
use std::str::FromStr;
use ark_bn254::Fr as F;
use ark_ff::{One, PrimeField};
use ark_sponge::poseidon::{PoseidonConfig, PoseidonSponge};
use ark_sponge::CryptographicSponge;
use ark_sponge::FieldBasedCryptographicSponge;
use ark_sponge::{CryptographicSponge, DuplexSpongeMode};
use lazy_static::lazy_static;
use num::{BigUint, FromPrimitive, Num, ToPrimitive, Zero};
use num::{BigUint, Num};
use regex::Regex;
use std::fs::File;
use std::io::prelude::*;
#[derive(Debug)]
enum Error {
Error1,
}
/// Additive round keys for a specific Poseidon rate.
/// The outer vec's length equals the total (partial + full) round count.
/// The inner vec's length equals rate + capacity (our Poseidon's capacity is fixed at 1).
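A hedged sketch of what these aliases plausibly look like given the description above (`F` is the arkworks bn254 scalar field imported by this module):
type Ark = Vec<Vec<F>>; // [total_round_count][rate + 1]
type Mds = Vec<Vec<F>>; // [rate + 1][rate + 1]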
@@ -40,7 +35,7 @@ pub struct Poseidon {
impl Poseidon {
pub fn new() -> Poseidon {
let (arks, mdss) = setup().unwrap();
let (arks, mdss) = setup();
Poseidon { arks, mdss }
}
@@ -75,8 +70,8 @@ impl Poseidon {
}
}
fn setup() -> Result<(Vec<Ark>, Vec<Mds>), Error> {
let mut file = File::open("poseidon_constants_old.circom").unwrap();
fn setup() -> (Vec<Ark>, Vec<Mds>) {
let mut file = File::open("circom/poseidon_constants_old.circom").unwrap();
let mut contents = String::new();
file.read_to_string(&mut contents).unwrap();
@@ -124,13 +119,5 @@ fn setup() -> Result<(Vec<Ark>, Vec<Mds>), Error> {
}
// we should have consumed all elements
assert!(v.len() == offset);
Ok((arks, mdss))
}
#[test]
fn test_stuff() {
let p = Poseidon::new();
let input = vec![BigUint::from_u8(0).unwrap(); 16];
let out = p.hash(&input);
print!("{:?} out:", out);
(arks, mdss)
}

View File

@@ -1,16 +1,85 @@
use super::poseidon::Poseidon;
use crate::prover::ProofInput;
use crate::prover::{Prove, ProverError};
use json::{object, stringify_pretty};
use num::{BigUint, FromPrimitive, ToPrimitive, Zero};
use std::env::temp_dir;
use std::fs;
use std::process::{Command, Output};
use uuid::Uuid;
pub struct ProverNode {}
pub struct Prover {
proving_key: Vec<u8>,
poseidon: Poseidon,
}
impl Prover {
pub fn new(proving_key: Vec<u8>) -> Self {
Self {
proving_key,
poseidon: Poseidon::new(),
}
}
// Creates inputs in the "input.json" format
fn create_proof_inputs(&self, input: ProofInput) -> String {
// convert each field element of plaintext into a string
let plaintext: Vec<String> = input
.plaintext
.iter()
.map(|bigint| bigint.to_string())
.collect();
// convert all deltas to strings
let deltas_str: Vec<String> = input.deltas.iter().map(|v| v.to_string()).collect();
// split deltas into groups corresponding to the field elements
// of our Poseidon circuit
let deltas_fes: Vec<&[String]> = deltas_str.chunks(self.useful_bits()).collect();
// prepare input.json
let input = object! {
plaintext_hash: input.plaintext_hash.to_string(),
label_sum_hash: input.label_sum_hash.to_string(),
sum_of_zero_labels: input.sum_of_zero_labels.to_string(),
plaintext: plaintext,
salt: input.salt.to_string(),
delta: deltas_fes[0..deltas_fes.len()-1],
// last field element's deltas are a separate input
delta_last: deltas_fes[deltas_fes.len()-1]
};
stringify_pretty(input, 4)
}
}
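For illustration, with hypothetical tiny values the string produced by create_proof_inputs() would look roughly like this (the field names come from the object! literal above; real values are decimal strings of BigUints, and delta holds 253 strings per field element):
{
    "plaintext_hash": "123",
    "label_sum_hash": "456",
    "sum_of_zero_labels": "789",
    "plaintext": ["1", "2", ...],
    "salt": "42",
    "delta": [["0", "1", ...], ...],
    "delta_last": ["1", "0", ...]
}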
impl Prove for Prover {
fn useful_bits(&self) -> usize {
253
}
fn poseidon_rate(&self) -> usize {
16
}
fn permutation_count(&self) -> usize {
1
}
fn salt_size(&self) -> usize {
128
}
fn chunk_size(&self) -> usize {
3920 // 253*15+125: 15 full field elements plus the high 125 bits of the last one
}
fn hash(&self, inputs: &Vec<BigUint>) -> Result<BigUint, ProverError> {
Ok(self.poseidon.hash(inputs))
}
impl Prove for ProverNode {
/// Produces a groth16 proof with snarkjs. The input is serialized into the
/// "input.json" format which snarkjs expects.
fn prove(&self, input: String, proving_key: &Vec<u8>) -> Result<Vec<u8>, ProverError> {
fn prove(&self, input: ProofInput) -> Result<Vec<u8>, ProverError> {
let mut path1 = temp_dir();
let mut path2 = temp_dir();
let mut path3 = temp_dir();
@@ -18,11 +87,13 @@ impl Prove for ProverNode {
path2.push(format!("proving_key.zkey.{}", Uuid::new_v4()));
path3.push(format!("proof.json.{}", Uuid::new_v4()));
let input = self.create_proof_inputs(input);
fs::write(path1.clone(), input).expect("Unable to write file");
fs::write(path2.clone(), proving_key).expect("Unable to write file");
fs::write(path2.clone(), self.proving_key.clone()).expect("Unable to write file");
let output = Command::new("node")
.args([
"prove.mjs",
"circom/prove.mjs",
path1.to_str().unwrap(),
path2.to_str().unwrap(),
path3.to_str().unwrap(),
@@ -51,10 +122,10 @@ impl Prove for ProverNode {
fn check_output(output: Result<Output, std::io::Error>) -> Result<(), ProverError> {
if output.is_err() {
return Err(ProverError::SnarkjsError);
return Err(ProverError::ProvingBackendError);
}
if !output.unwrap().status.success() {
return Err(ProverError::SnarkjsError);
return Err(ProverError::ProvingBackendError);
}
Ok(())
}

View File

@@ -1,3 +1,4 @@
use crate::verifier::VerificationInput;
use crate::verifier::{VerifierError, Verify};
use json::{array, object, stringify, stringify_pretty, JsonValue};
use num::{BigUint, FromPrimitive, ToPrimitive, Zero};
@@ -7,23 +8,36 @@ use std::path::Path;
use std::process::{Command, Output};
use uuid::Uuid;
pub struct VerifierNode {}
pub struct Verifier {
verification_key: Vec<u8>,
}
impl Verifier {
pub fn new(verification_key: Vec<u8>) -> Self {
Self { verification_key }
}
}
impl Verify for VerifierNode {
fn verify(
&self,
proof: Vec<u8>,
deltas: Vec<String>,
plaintext_hash: BigUint,
labelsum_hash: BigUint,
zero_sum: BigUint,
) -> Result<bool, VerifierError> {
impl Verify for Verifier {
fn field_size(&self) -> usize {
254
}
fn useful_bits(&self) -> usize {
253
}
fn chunk_size(&self) -> usize {
3920 // 253*15+125: 15 full field elements plus the high 125 bits of the last one
}
fn verify(&self, input: VerificationInput) -> Result<bool, VerifierError> {
// public.json is a flat array
let mut public_json: Vec<String> = Vec::new();
public_json.push(plaintext_hash.to_string());
public_json.push(labelsum_hash.to_string());
public_json.extend::<Vec<String>>(deltas);
public_json.push(zero_sum.to_string());
public_json.push(input.plaintext_hash.to_string());
public_json.push(input.label_sum_hash.to_string());
let delta_str: Vec<String> = input.deltas.iter().map(|v| v.to_string()).collect();
public_json.extend::<Vec<String>>(delta_str);
public_json.push(input.sum_of_zero_labels.to_string());
let s = stringify(JsonValue::from(public_json.clone()));
// write into temp files and delete the files after verification
@@ -32,11 +46,11 @@ impl Verify for VerifierNode {
path1.push(format!("public.json.{}", Uuid::new_v4()));
path2.push(format!("proof.json.{}", Uuid::new_v4()));
fs::write(path1.clone(), s).expect("Unable to write file");
fs::write(path2.clone(), proof).expect("Unable to write file");
fs::write(path2.clone(), input.proof).expect("Unable to write file");
let output = Command::new("node")
.args([
"verify.mjs",
"circom/verify.mjs",
path1.to_str().unwrap(),
path2.to_str().unwrap(),
])
@@ -54,10 +68,10 @@ impl Verify for VerifierNode {
fn check_output(output: &Result<Output, std::io::Error>) -> Result<(), VerifierError> {
if output.is_err() {
return Err(VerifierError::SnarkjsError);
return Err(VerifierError::VerifyingBackendError);
}
if !output.as_ref().unwrap().status.success() {
return Err(VerifierError::SnarkjsError);
return Err(VerifierError::VerifyingBackendError);
}
Ok(())
}

190
src/utils.rs Normal file
View File

@@ -0,0 +1,190 @@
use crate::{Delta, ZeroSum};
use aes::{Aes128, NewBlockCipher};
use ark_ff::BigInt;
use cipher::{consts::U16, generic_array::GenericArray, BlockEncrypt};
use num::BigUint;
use sha2::{Digest, Sha256};
/// Converts bits in MSB-first order into a `BigUint`
pub fn bits_to_bigint(bits: &[bool]) -> BigUint {
BigUint::from_bytes_be(&boolvec_to_u8vec(bits))
}
#[test]
fn test_bits_to_bigint() {
let bits = [true, false];
assert_eq!(bits_to_bigint(&bits), 2u8.into());
}
/// Converts bits in MSB-first order into BE bytes. The bits will be left-padded
/// with zeroes to the nearest multiple of 8.
pub fn boolvec_to_u8vec(bv: &[bool]) -> Vec<u8> {
let rem = bv.len() % 8;
let first_byte_bitsize = if rem == 0 { 8 } else { rem };
let offset = if rem == 0 { 0 } else { 1 };
let mut v = vec![0u8; bv.len() / 8 + offset];
// implicitly left-pad the first byte with zeroes
for (i, b) in bv[0..first_byte_bitsize].iter().enumerate() {
v[i / 8] |= (*b as u8) << (first_byte_bitsize - 1 - i);
}
for (i, b) in bv[first_byte_bitsize..].iter().enumerate() {
v[1 + i / 8] |= (*b as u8) << (7 - (i % 8));
}
v
}
#[test]
fn test_boolvec_to_u8vec() {
let bits = [true, false];
assert_eq!(boolvec_to_u8vec(&bits), [2]);
let bits = [true, false, false, false, false, false, false, true, true];
assert_eq!(boolvec_to_u8vec(&bits), [1, 3]);
}
/// Converts BE bytes into bits in MSB-first order. The output length is
/// always a multiple of 8.
pub fn u8vec_to_boolvec(v: &[u8]) -> Vec<bool> {
let mut bv = Vec::with_capacity(v.len() * 8);
for byte in v.iter() {
for i in 0..8 {
bv.push(((byte >> (7 - i)) & 1) != 0);
}
}
bv
}
#[test]
fn test_u8vec_to_boolvec() {
let bytes = [1];
assert_eq!(
u8vec_to_boolvec(&bytes),
[false, false, false, false, false, false, false, true]
);
let bytes = [255, 2];
assert_eq!(
u8vec_to_boolvec(&bytes),
[
true, true, true, true, true, true, true, true, false, false, false, false, false,
false, true, false
]
);
// convert to bits and back to bytes
let bignum: BigUint = 3898219876643u128.into();
let bits = u8vec_to_boolvec(&bignum.to_bytes_be());
let bytes = boolvec_to_u8vec(&bits);
assert_eq!(bignum, BigUint::from_bytes_be(&bytes));
}
/// Returns sha256 hash digest
pub fn sha256(data: &[u8]) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(data);
hasher.finalize().into()
}
/// Encrypts each arithmetic label using a corresponding binary label as a key
/// and returns ciphertexts in an order based on binary label's pointer bit (LSB).
pub fn encrypt_arithmetic_labels(
alabels: &Vec<[BigUint; 2]>,
blabels: &Vec<[u128; 2]>,
) -> Result<Vec<[[u8; 16]; 2]>, String> {
if alabels.len() != blabels.len() {
return Err("arithmetic and binary label counts must match".to_string());
}
Ok(blabels
.iter()
.zip(alabels)
.map(|(bin_pair, arithm_pair)| {
// safe to unwrap() since to_be_bytes() always returns exactly 16
// bytes for u128
let zero_key = Aes128::new_from_slice(&bin_pair[0].to_be_bytes()).unwrap();
let one_key = Aes128::new_from_slice(&bin_pair[1].to_be_bytes()).unwrap();
let mut label0 = [0u8; 16];
let mut label1 = [0u8; 16];
let ap0 = arithm_pair[0].to_bytes_be();
let ap1 = arithm_pair[1].to_bytes_be();
// pad with zeroes on the left
label0[16 - ap0.len()..].copy_from_slice(&ap0);
label1[16 - ap1.len()..].copy_from_slice(&ap1);
let mut label0: GenericArray<u8, U16> = GenericArray::from(label0);
let mut label1: GenericArray<u8, U16> = GenericArray::from(label1);
zero_key.encrypt_block(&mut label0);
one_key.encrypt_block(&mut label1);
// place encrypted arithmetic labels based on the pointer bit of
// binary label 0
if (bin_pair[0] & 1) == 0 {
[label0.into(), label1.into()]
} else {
[label1.into(), label0.into()]
}
})
.collect())
}
#[test]
fn test_encrypt_arithmetic_labels() {
let alabels: [BigUint; 2] = [3u8.into(), 4u8.into()];
let blabels = [0u128, 1u128];
let res = encrypt_arithmetic_labels(&vec![alabels], &vec![blabels]).unwrap();
let flat = res[0].into_iter().flatten().collect::<Vec<_>>();
// expected value generated with python3:
// from Crypto.Cipher import AES
// k0 = AES.new((0).to_bytes(16, 'big'), AES.MODE_ECB)
// ct0 = k0.encrypt((3).to_bytes(16, 'big')).hex()
// k1 = AES.new((1).to_bytes(16, 'big'), AES.MODE_ECB)
// ct1 = k1.encrypt((4).to_bytes(16, 'big')).hex()
// print(ct0+ct1)
let expected = "f795aaab494b5923f7fd89ff948bc1e0382fa171550467b34c54c58b9d3cfd24";
assert_eq!(hex::encode(&flat), expected);
}
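The Prover-side counterpart is not part of this file; a hedged sketch of how decryption would work, relying on the ordering established above (the ciphertext at index i is encrypted under the binary label whose pointer bit equals i):
use aes::{Aes128, NewBlockCipher};
use cipher::{consts::U16, generic_array::GenericArray, BlockDecrypt};
use num::BigUint;
/// Recovers the Prover's active arithmetic label, using her active binary
/// label as the AES key. Hypothetical helper, not part of this commit.
fn decrypt_active_label(active_blabel: u128, ct_pair: &[[u8; 16]; 2]) -> BigUint {
// the pointer bit (LSB) selects which of the two ciphertexts is ours
let pointer_bit = (active_blabel & 1) as usize;
let key = Aes128::new_from_slice(&active_blabel.to_be_bytes()).unwrap();
let mut block: GenericArray<u8, U16> = GenericArray::from(ct_pair[pointer_bit]);
key.decrypt_block(&mut block);
BigUint::from_bytes_be(&block)
}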
/// Returns the sum of all zero labels and the delta for each label pair.
pub fn compute_zero_sum_and_deltas(
arithmetic_label_pairs: &[[BigUint; 2]],
) -> (ZeroSum, Vec<Delta>) {
let mut deltas: Vec<Delta> = Vec::with_capacity(arithmetic_label_pairs.len());
let mut zero_sum: ZeroSum = 0u8.into();
for label_pair in arithmetic_label_pairs {
// calculate the sum of all zero labels
zero_sum += label_pair[0].clone();
// put all deltas into one vec
deltas.push(label_pair[1].clone() - label_pair[0].clone());
}
(zero_sum, deltas)
}
#[test]
/// Tests compute_zero_sum_and_deltas()
fn test_compute_zero_sum_and_deltas() {
let labels: [[BigUint; 2]; 2] = [[1u8.into(), 2u8.into()], [3u8.into(), 4u8.into()]];
let (z, d) = compute_zero_sum_and_deltas(&labels);
assert_eq!(z, 4u8.into());
assert_eq!(d, [1u8.into(), 1u8.into()]);
}
/// Makes sure that the `BigUint`'s bitsize is not larger than `bitsize`
pub fn sanitize_biguint(input: &BigUint, bitsize: usize) -> Result<(), String> {
if (input.bits() as usize) > bitsize {
Err("BigUint is larger than allowed".to_string())
} else {
Ok(())
}
}
#[test]
/// Tests sanitize_biguint()
fn test_sanitize_biguint() {
let good = BigUint::from(2u8).pow(253) - BigUint::from(1u8);
let res = sanitize_biguint(&good, 253);
assert!(res.is_ok());
let bad = BigUint::from(2u8).pow(253);
let res = sanitize_biguint(&bad, 253);
assert!(res.is_err());
}

View File

@@ -1,23 +1,26 @@
use crate::boolvec_to_u8vec;
use super::ARITHMETIC_LABEL_SIZE;
use crate::label::{LabelGenerator, Seed};
use crate::utils::{compute_zero_sum_and_deltas, encrypt_arithmetic_labels, sanitize_biguint};
use crate::{Delta, LabelsumHash, PlaintextHash, Proof, ZeroSum};
use num::BigUint;
use super::{encrypt_arithmetic_labels, random_bigint, ARITHMETIC_LABEL_SIZE, POSEIDON_RATE};
use crate::label::{LabelGenerator, LabelPair, Seed};
use aes::{Aes128, NewBlockCipher};
use cipher::{consts::U16, generic_array::GenericArray, BlockCipher, BlockEncrypt};
use num::{BigUint, FromPrimitive, ToPrimitive, Zero};
use rand::SeedableRng;
use rand::{thread_rng, Rng};
use rand_chacha::ChaCha20Rng;
// The PRG we use to generate arithmetic labels
type Prg = ChaCha20Rng;
#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub enum VerifierError {
ProvingKeyNotFound,
FileSystemError,
FileDoesNotExist,
SnarkjsError,
WrongProofCount,
BigUintTooLarge,
VerifyingBackendError,
VerificationFailed,
InternalError,
}
/// Public inputs and a zk proof that needs to be verified.
#[derive(Default)]
pub struct VerificationInput {
pub plaintext_hash: PlaintextHash,
pub label_sum_hash: LabelsumHash,
pub sum_of_zero_labels: ZeroSum,
pub deltas: Vec<Delta>,
pub proof: Proof,
}
pub trait State {}
@@ -25,51 +28,57 @@ pub trait State {}
pub struct Setup {
binary_labels: Vec<[u128; 2]>,
}
impl State for Setup {}
#[derive(Default)]
pub struct ReceivePlaintextHashes {
deltas: Vec<BigUint>,
/// The sum of all arithmetic labels with the semantic value 0. One sum
/// for each chunk of the plaintext.
zero_sums: Vec<BigUint>,
ciphertexts: Vec<[Vec<u8>; 2]>,
/// The PRG seed from which all arithmetic labels were generated
deltas: Vec<Delta>,
zero_sums: Vec<ZeroSum>,
ciphertexts: Vec<[[u8; 16]; 2]>,
arith_label_seed: Seed,
}
impl State for ReceivePlaintextHashes {}
#[derive(Default)]
pub struct ReceiveLabelsumHashes {
deltas: Vec<BigUint>,
zero_sums: Vec<BigUint>,
// hashes for each chunk of Prover's plaintext
plaintext_hashes: Vec<BigUint>,
arith_label_seed: [u8; 32],
deltas: Vec<Delta>,
zero_sums: Vec<ZeroSum>,
plaintext_hashes: Vec<PlaintextHash>,
arith_label_seed: Seed,
}
impl State for ReceiveLabelsumHashes {}
#[derive(Default)]
pub struct VerifyMany {
deltas: Vec<BigUint>,
zero_sums: Vec<BigUint>,
plaintext_hashes: Vec<BigUint>,
labelsum_hashes: Vec<BigUint>,
deltas: Vec<Delta>,
zero_sums: Vec<ZeroSum>,
plaintext_hashes: Vec<PlaintextHash>,
labelsum_hashes: Vec<LabelsumHash>,
}
impl State for VerifyMany {}
pub struct VerificationSuccessfull {
plaintext_hashes: Vec<BigUint>,
plaintext_hashes: Vec<PlaintextHash>,
}
impl State for Setup {}
impl State for ReceivePlaintextHashes {}
impl State for ReceiveLabelsumHashes {}
impl State for VerifyMany {}
impl State for VerificationSuccessfull {}
pub trait Verify {
fn verify(
&self,
proof: Vec<u8>,
deltas: Vec<String>,
plaintext_hash: BigUint,
labelsum_hash: BigUint,
zero_sum: BigUint,
) -> Result<bool, VerifierError>;
/// Verifies the zk proof against public `input`s. Returns `true` on success,
/// `false` otherwise.
fn verify(&self, input: VerificationInput) -> Result<bool, VerifierError>;
/// The EC field size in bits. Verifier uses this to sanitize the `BigUint`s
/// received from Prover.
fn field_size(&self) -> usize;
/// Returns how many bits of plaintext we will pack into one field element.
/// Normally, this should be [Verify::field_size] minus 1.
fn useful_bits(&self) -> usize;
/// How many bits of [Plaintext] can fit into one [Chunk]. This does not
/// include the [Salt] of the hash, which takes up the low bits of the
/// last field element of each chunk.
fn chunk_size(&self) -> usize;
}
pub struct LabelsumVerifier<S = Setup>
where
@@ -80,6 +89,7 @@ where
}
impl LabelsumVerifier {
/// Creates a new verifier in its initial state.
pub fn new(
binary_labels: Vec<[u128; 2]>,
verifier: Box<dyn Verify>,
@@ -92,40 +102,33 @@ impl LabelsumVerifier {
}
impl LabelsumVerifier<Setup> {
/// Generates arith. labels from a seed and encrypts them using binary labels
/// as encryption keys.
/// Generates arithmetic labels from a seed, computes the deltas and the
/// sums of zero labels, and encrypts the arithmetic labels using the
/// binary labels as encryption keys.
///
/// Returns the next expected state.
pub fn setup(self) -> Result<LabelsumVerifier<ReceivePlaintextHashes>, VerifierError> {
let plaintext_bitsize = self.state.binary_labels.len();
// Compute useful bits from the field prime
let chunk_size = 253 * POSEIDON_RATE - 128;
// count of chunks rounded up
let chunk_count = (plaintext_bitsize + (chunk_size - 1)) / chunk_size;
// There will be as many deltas as there are garbled circuit output
// labels.
let mut deltas: Vec<BigUint> = Vec::with_capacity(self.state.binary_labels.len());
// There will be as many zero_sums as there are chunks
let mut zero_sums: Vec<BigUint> = Vec::with_capacity(chunk_count);
let (label_pairs, seed) =
LabelGenerator::generate(self.state.binary_labels.len(), ARITHMETIC_LABEL_SIZE);
// There will be as many deltas as there are plaintext bits.
let mut deltas: Vec<BigUint> = Vec::with_capacity(plaintext_bitsize);
let zero_sums: Vec<ZeroSum> = label_pairs
.chunks(self.verifier.chunk_size())
.map(|chunk_of_alabel_pairs| {
let (zero_sum, deltas_in_chunk) =
compute_zero_sum_and_deltas(chunk_of_alabel_pairs);
deltas.extend(deltas_in_chunk);
zero_sum
})
.collect();
// Generate arithmetic label pairs and split them into chunks
let generator = LabelGenerator::new();
let (label_pairs, seed, _generator) =
generator.generate(plaintext_bitsize, ARITHMETIC_LABEL_SIZE);
let label_pair_chunks = label_pairs.chunks(chunk_size);
// Calculate deltas for all chunks and zero_sums for each chunk
for chunk in label_pair_chunks {
let mut zero_sum = BigUint::from_u8(0).unwrap();
for label_pair in chunk {
zero_sum += label_pair[0].clone();
deltas.push(label_pair[1].clone() - label_pair[0].clone());
}
zero_sums.push(zero_sum);
}
// encrypt each arithmetic label using a corresponding binary label as a key
// place ciphertexts in an order based on binary label's p&p bit
let ciphertexts = encrypt_arithmetic_labels(&label_pairs, &self.state.binary_labels);
let ciphertexts = match encrypt_arithmetic_labels(&label_pairs, &self.state.binary_labels) {
Ok(ct) => ct,
Err(_) => return Err(VerifierError::InternalError),
};
Ok(LabelsumVerifier {
state: ReceivePlaintextHashes {
@@ -140,12 +143,19 @@ impl LabelsumVerifier<Setup> {
}
impl LabelsumVerifier<ReceivePlaintextHashes> {
// receive hashes of plaintext and reveal the encrypted arithmetic labels
/// Receives hashes of plaintext and returns the encrypted
/// arithmetic labels and the next expected state.
pub fn receive_plaintext_hashes(
self,
plaintext_hashes: Vec<BigUint>,
) -> (Vec<[Vec<u8>; 2]>, LabelsumVerifier<ReceiveLabelsumHashes>) {
(
plaintext_hashes: Vec<PlaintextHash>,
) -> Result<(Vec<[[u8; 16]; 2]>, LabelsumVerifier<ReceiveLabelsumHashes>), VerifierError> {
for h in &plaintext_hashes {
if sanitize_biguint(h, self.verifier.field_size()).is_err() {
return Err(VerifierError::BigUintTooLarge);
}
}
Ok((
self.state.ciphertexts,
LabelsumVerifier {
state: ReceiveLabelsumHashes {
@@ -156,18 +166,24 @@ impl LabelsumVerifier<ReceivePlaintextHashes> {
},
verifier: self.verifier,
},
)
))
}
}
impl LabelsumVerifier<ReceiveLabelsumHashes> {
// receive the hash commitment to the Prover's sum of labels and reveal
// the arithmetic label seed
/// Receives hashes of sums of labels and returns the arithmetic label [Seed]
/// and the next expected state.
pub fn receive_labelsum_hashes(
self,
labelsum_hashes: Vec<BigUint>,
) -> (Seed, LabelsumVerifier<VerifyMany>) {
(
labelsum_hashes: Vec<LabelsumHash>,
) -> Result<(Seed, LabelsumVerifier<VerifyMany>), VerifierError> {
for h in &labelsum_hashes {
if sanitize_biguint(h, self.verifier.field_size()).is_err() {
return Err(VerifierError::BigUintTooLarge);
}
}
Ok((
self.state.arith_label_seed,
LabelsumVerifier {
state: VerifyMany {
@@ -178,60 +194,27 @@ impl LabelsumVerifier<ReceiveLabelsumHashes> {
},
verifier: self.verifier,
},
)
))
}
}
impl LabelsumVerifier<VerifyMany> {
/// Verifies as many proofs as there are [Chunk]s of the plaintext. Returns
/// the next expected state.
pub fn verify_many(
self,
proofs: Vec<Vec<u8>>,
mut self,
proofs: Vec<Proof>,
) -> Result<LabelsumVerifier<VerificationSuccessfull>, VerifierError> {
// // Write public.json. The elements must be written in the exact order
// // as below, that's the order snarkjs expects them to be in.
// the last chunk will be padded with zero plaintext. We also should pad
// the deltas of the last chunk
// TODO remove this hard-coding
let useful_bits = 253;
// the size of a chunk of plaintext not counting the salt
let chunk_size = useful_bits * 16 - 128;
let chunk_count = (self.state.deltas.len() + (chunk_size - 1)) / chunk_size;
assert!(proofs.len() == chunk_count);
// pad deltas with 0 values to make their count a multiple of a chunk size
let delta_pad_count = chunk_size * chunk_count - self.state.deltas.len();
let mut deltas = self.state.deltas.clone();
deltas.extend(vec![BigUint::from_u8(0).unwrap(); delta_pad_count]);
let deltas_chunks: Vec<&[BigUint]> = deltas.chunks(chunk_size).collect();
for count in 0..chunk_count {
// There are as many deltas as there are bits in the chunk of the
// plaintext (not counting the salt)
let delta_str: Vec<String> =
deltas_chunks[count].iter().map(|v| v.to_string()).collect();
let plaintext_hash = self.state.plaintext_hashes[count].clone();
let labelsum_hash = self.state.labelsum_hashes[count].clone();
let zero_sum = self.state.zero_sums[count].clone();
let res = self.verifier.verify(
proofs[count].clone(),
delta_str,
plaintext_hash,
labelsum_hash,
zero_sum,
);
// checking both for good measure
if res.is_err() {
return Err(VerifierError::VerificationFailed);
}
// shouldn't get here if there was an error, but will check anyway
if res.unwrap() != true {
let inputs = self.create_verification_inputs(proofs)?;
for input in inputs {
let res = self.verifier.verify(input)?;
if !res {
// the `?` above only propagates backend errors; a backend signals
// failed verification by returning Ok(false), which we turn into
// an error here
return Err(VerifierError::VerificationFailed);
}
}
Ok(LabelsumVerifier {
state: VerificationSuccessfull {
plaintext_hashes: self.state.plaintext_hashes,
@@ -239,4 +222,208 @@ impl LabelsumVerifier<VerifyMany> {
verifier: self.verifier,
})
}
/// Constructs the public inputs for the zk circuit for each [Chunk].
fn create_verification_inputs(
&mut self,
proofs: Vec<Proof>,
) -> Result<Vec<VerificationInput>, VerifierError> {
// How many chunks of plaintext are there? (== how many zk proofs to expect)
// The number of deltas corresponds to the number of bits in the plaintext.
// Round up the chunk count.
let chunk_count = (self.state.deltas.len() + (self.verifier.chunk_size() - 1))
/ self.verifier.chunk_size();
if proofs.len() != chunk_count {
return Err(VerifierError::WrongProofCount);
}
// Since the last chunk of plaintext is padded with zero bits, we also zero-pad
// the corresponding deltas of the last chunk to the size of a chunk.
let delta_pad_count = self.verifier.chunk_size() * chunk_count - self.state.deltas.len();
let mut deltas = self.state.deltas.clone();
deltas.extend(vec![0u8.into(); delta_pad_count]);
let chunks_of_deltas = deltas
.chunks(self.verifier.chunk_size())
.map(|i| i.to_vec())
.collect::<Vec<Vec<_>>>();
Ok((0..chunk_count)
.map(|i| VerificationInput {
plaintext_hash: self.state.plaintext_hashes[i].clone(),
label_sum_hash: self.state.labelsum_hashes[i].clone(),
sum_of_zero_labels: self.state.zero_sums[i].clone(),
deltas: chunks_of_deltas[i].clone(),
proof: proofs[i].clone(),
})
.collect())
}
}
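A worked example of the rounding and padding above (hypothetical sizes; 3920 is the snarkjs backend's chunk_size()):
let chunk_size = 3920usize;
let delta_count = 4000usize; // == number of plaintext bits
// round up: 2 chunks, hence 2 proofs are expected
let chunk_count = (delta_count + chunk_size - 1) / chunk_size;
assert_eq!(chunk_count, 2);
// the last chunk's deltas are zero-padded to a full chunk
assert_eq!(chunk_size * chunk_count - delta_count, 3840);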
#[cfg(test)]
mod tests {
use crate::verifier::LabelsumVerifier;
use crate::verifier::ReceiveLabelsumHashes;
use crate::verifier::ReceivePlaintextHashes;
use crate::verifier::VerificationInput;
use crate::verifier::VerifierError;
use crate::verifier::Verify;
use crate::verifier::VerifyMany;
use crate::Proof;
use num::BigUint;
/// The verifier who implements `Verify` with the correct values
struct CorrectTestVerifier {}
impl Verify for CorrectTestVerifier {
fn verify(&self, _input: VerificationInput) -> Result<bool, VerifierError> {
Ok(true)
}
fn field_size(&self) -> usize {
254
}
fn useful_bits(&self) -> usize {
253
}
fn chunk_size(&self) -> usize {
3670
}
}
#[test]
/// Provide a `BigUint` larger than field_size() and trigger
/// [VerifierError::BigUintTooLarge]
fn test_error_biguint_too_large() {
// test receive_plaintext_hashes()
let lsv = LabelsumVerifier {
state: ReceivePlaintextHashes::default(),
verifier: Box::new(CorrectTestVerifier {}),
};
let mut hashes: Vec<BigUint> = (0..100).map(|i| BigUint::from(i as u64)).collect();
hashes[50] = BigUint::from(2u8).pow(lsv.verifier.field_size() as u32);
let res = lsv.receive_plaintext_hashes(hashes);
assert_eq!(res.err().unwrap(), VerifierError::BigUintTooLarge);
// test receive_labelsum_hashes
let lsv = LabelsumVerifier {
state: ReceiveLabelsumHashes::default(),
verifier: Box::new(CorrectTestVerifier {}),
};
let mut plaintext_hashes: Vec<BigUint> =
(0..100).map(|i| BigUint::from(i as u64)).collect();
plaintext_hashes[50] = BigUint::from(2u8).pow(lsv.verifier.field_size() as u32);
let res = lsv.receive_labelsum_hashes(plaintext_hashes);
assert_eq!(res.err().unwrap(), VerifierError::BigUintTooLarge);
}
#[test]
/// Provide too many/too few proofs and trigger [VerifierError::WrongProofCount]
fn test_error_wrong_proof_count() {
// 3 chunks
let lsv = LabelsumVerifier {
state: VerifyMany {
deltas: vec![0u8.into(); 3670 * 2 + 1],
..Default::default()
},
verifier: Box::new(CorrectTestVerifier {}),
};
// 4 proofs
let res = lsv.verify_many(vec![Proof::default(); 4]);
assert_eq!(res.err().unwrap(), VerifierError::WrongProofCount);
// 3 chunks
let lsv = LabelsumVerifier {
state: VerifyMany {
deltas: vec![0u8.into(); 3670 * 2 + 1],
..Default::default()
},
verifier: Box::new(CorrectTestVerifier {}),
};
// 2 proofs
let res = lsv.verify_many(vec![Proof::default(); 2]);
assert_eq!(res.err().unwrap(), VerifierError::WrongProofCount);
}
#[test]
/// The backend returns `false` when attempting to verify, triggering
/// [VerifierError::VerificationFailed]
fn test_error_verification_failed() {
struct TestVerifier {}
impl Verify for TestVerifier {
fn verify(&self, _input: VerificationInput) -> Result<bool, VerifierError> {
Ok(false)
}
fn field_size(&self) -> usize {
254
}
fn useful_bits(&self) -> usize {
253
}
fn chunk_size(&self) -> usize {
3670
}
}
let lsv = LabelsumVerifier {
state: VerifyMany {
deltas: vec![0u8.into(); 3670 * 2 - 1],
zero_sums: vec![0u8.into(); 2],
plaintext_hashes: vec![0u8.into(); 2],
labelsum_hashes: vec![0u8.into(); 2],
},
verifier: Box::new(TestVerifier {}),
};
let res = lsv.verify_many(vec![Proof::default(); 2]);
assert_eq!(res.err().unwrap(), VerifierError::VerificationFailed);
}
#[test]
/// The backend returns an error unrelated to the verification result;
/// checks that the error propagates.
fn test_verification_error() {
struct TestVerifier {}
impl Verify for TestVerifier {
fn verify(&self, _input: VerificationInput) -> Result<bool, VerifierError> {
Err(VerifierError::VerifyingBackendError)
}
fn field_size(&self) -> usize {
254
}
fn useful_bits(&self) -> usize {
253
}
fn chunk_size(&self) -> usize {
3670
}
}
let lsv = LabelsumVerifier {
state: VerifyMany {
deltas: vec![0u8.into(); 3670 * 2 - 1],
zero_sums: vec![0u8.into(); 2],
plaintext_hashes: vec![0u8.into(); 2],
labelsum_hashes: vec![0u8.into(); 2],
},
verifier: Box::new(TestVerifier {}),
};
let res = lsv.verify_many(vec![Proof::default(); 2]);
assert_eq!(res.err().unwrap(), VerifierError::VerifyingBackendError);
}
}