mirror of
https://github.com/tlsnotary/label_decoding.git
synced 2026-01-08 03:33:52 -05:00
halo2 wip and tests
This commit is contained in:
@@ -52,3 +52,6 @@ ark-std = { git = "https://github.com/arkworks-rs/std" }
|
||||
ark-ec = { git = "https://github.com/arkworks-rs/algebra" }
|
||||
ark-ff = { git = "https://github.com/arkworks-rs/algebra" }
|
||||
ark-serialize = { git = "https://github.com/arkworks-rs/algebra" }
|
||||
|
||||
[dev-dependencies]
|
||||
hex = "0.4"
|
||||
29
README
29
README
@@ -1,29 +0,0 @@
|
||||
This repo generates a circom circuit which is used to decode output labels from GC.
|
||||
|
||||
Install snarkjs https://github.com/iden3/snarkjs
|
||||
Download powers of tau^14 https://hermez.s3-eu-west-1.amazonaws.com/powersOfTau28_hez_final_14.ptau
|
||||
|
||||
Run:
|
||||
python3 script.py 15
|
||||
# 10 is how much plaintext (in Field elements of ~32 bytes) we want
|
||||
# to decode inside the snark. (For tau^14 max is 21)
|
||||
# if you need more than 21, you'll need to download another ptau file from
|
||||
# https://github.com/iden3/snarkjs#7-prepare-phase-2
|
||||
circom circuit.circom --r1cs --wasm
|
||||
|
||||
snarkjs groth16 setup circuit.r1cs powersOfTau28_hez_final_14.ptau circuit_0000.zkey
|
||||
|
||||
# snarkjs groth16 setup circuit.r1cs pot14_bls12_final.ptau circuit_0000.zkey
|
||||
|
||||
snarkjs zkey contribute circuit_0000.zkey circuit_final.zkey -v -e="Notary's entropy"
|
||||
snarkjs zkey export verificationkey circuit_final.zkey verification_key.json
|
||||
snarkjs groth16 fullprove input.json circuit_js/circuit.wasm circuit_final.zkey proof.json public.json
|
||||
|
||||
|
||||
snarkjs groth16 verify verification_key.json public.json proof.json
|
||||
|
||||
|
||||
|
||||
We can generate circuit.wasm and circuit.r1cs deterministically with circom 2.0.5+
|
||||
circom circuit.circom --r1cs --wasm
|
||||
and then ship .wasm on the User side and .r1cs on the Notary side
|
||||
24
circom/README
Normal file
24
circom/README
Normal file
@@ -0,0 +1,24 @@
|
||||
This folder contains files used by the snarkjs_backend.
|
||||
To use that backend, make sure you have node installed (tested on Node v16.17.1)
|
||||
Install dependencies with:
|
||||
|
||||
npm install
|
||||
|
||||
powersOfTau28_hez_final_14.ptau was downloaded from https://hermez.s3-eu-west-1.amazonaws.com/powersOfTau28_hez_final_14.ptau
|
||||
|
||||
Whenever circuit.circom is modified, delete circuit_0000.zkey and run:
|
||||
circom circuit.circom --r1cs --wasm
|
||||
|
||||
All the commands below will be run by the prover/verifier from the .mjs files:
|
||||
|
||||
snarkjs groth16 setup circuit.r1cs powersOfTau28_hez_final_14.ptau circuit_0000.zkey
|
||||
snarkjs zkey contribute circuit_0000.zkey circuit_final.zkey -v -e="Notary's entropy"
|
||||
snarkjs zkey export verificationkey circuit_final.zkey verification_key.json
|
||||
snarkjs groth16 fullprove input.json circuit_js/circuit.wasm circuit_final.zkey proof.json public.json
|
||||
snarkjs groth16 verify verification_key.json public.json proof.json
|
||||
|
||||
|
||||
|
||||
We can generate circuit.wasm and circuit.r1cs deterministically with circom 2.0.5+
|
||||
circom circuit.circom --r1cs --wasm
|
||||
and then ship circuit.wasm on the User side and circuit.r1cs on the Notary side
|
||||
@@ -3,7 +3,7 @@ include "./poseidon.circom";
|
||||
include "./utils.circom";
|
||||
|
||||
template Main() {
|
||||
// Poseidon hash width (how many field elements are permuted at a time)
|
||||
// Poseidon hash rate (how many field elements are permuted at a time)
|
||||
var w = 16;
|
||||
// The amount of last field element's high bits (in big-endian) to use for
|
||||
// the plaintext. The rest of it will be used for the salt.
|
||||
@@ -12,19 +12,20 @@ template Main() {
|
||||
signal input plaintext_hash;
|
||||
signal input label_sum_hash;
|
||||
signal input plaintext[w];
|
||||
signal input labelsum_salt;
|
||||
signal input salt;
|
||||
signal input delta[w-1][253];
|
||||
signal input delta_last[last_fe_bits];
|
||||
signal input sum_of_zero_labels;
|
||||
signal sums[w];
|
||||
|
||||
// acc.to the Poseidon paper, the 2nd element of the Poseidon state
|
||||
// is the hash digest
|
||||
component hash = PoseidonEx(w, 2);
|
||||
hash.initialState <== 0;
|
||||
for (var i = 0; i<w; i++) {
|
||||
for (var i = 0; i < w-1; i++) {
|
||||
hash.inputs[i] <== plaintext[i];
|
||||
}
|
||||
//add salt to the last element of plaintext shifting it left first
|
||||
hash.inputs[w-1] <== plaintext[w-1] * (1 << 128) + salt;
|
||||
log(1);
|
||||
plaintext_hash === hash.out[1];
|
||||
log(2);
|
||||
@@ -41,7 +42,7 @@ template Main() {
|
||||
var useful_bits = i < w-1 ? 253 : last_fe_bits;
|
||||
ip[i] = InnerProd(useful_bits);
|
||||
ip[i].plaintext <== plaintext[i];
|
||||
for (var j=0; j<useful_bits; j++) {
|
||||
for (var j=0; j < useful_bits; j++) {
|
||||
if (i < w-1){
|
||||
ip[i].deltas[j] <== delta[i][j];
|
||||
}
|
||||
@@ -57,8 +58,9 @@ template Main() {
|
||||
component ls_hash = PoseidonEx(1, 2);
|
||||
ls_hash.initialState <== 0;
|
||||
// shift the sum to the left and put the salt into the last 128 bits
|
||||
ls_hash.inputs[0] <== (sum_of_zero_labels + sum_of_deltas[w]) * (1 << 128) + labelsum_salt;
|
||||
ls_hash.inputs[0] <== (sum_of_zero_labels + sum_of_deltas[w]) * (1 << 128) + salt;
|
||||
log(3);
|
||||
label_sum_hash === ls_hash.out[1];
|
||||
log(4);
|
||||
}
|
||||
component main {public [sum_of_zero_labels, plaintext_hash, label_sum_hash, delta, delta_last]} = Main();
|
||||
BIN
circom/circuit.r1cs
Normal file
BIN
circom/circuit.r1cs
Normal file
Binary file not shown.
BIN
circom/circuit_js/circuit.wasm
Normal file
BIN
circom/circuit_js/circuit.wasm
Normal file
Binary file not shown.
20
circom/circuit_js/generate_witness.js
Normal file
20
circom/circuit_js/generate_witness.js
Normal file
@@ -0,0 +1,20 @@
|
||||
const wc = require("./witness_calculator.js");
|
||||
const { readFileSync, writeFile } = require("fs");
|
||||
|
||||
if (process.argv.length != 5) {
|
||||
console.log("Usage: node generate_witness.js <file.wasm> <input.json> <output.wtns>");
|
||||
} else {
|
||||
const input = JSON.parse(readFileSync(process.argv[3], "utf8"));
|
||||
|
||||
const buffer = readFileSync(process.argv[2]);
|
||||
wc(buffer).then(async witnessCalculator => {
|
||||
// const w= await witnessCalculator.calculateWitness(input,0);
|
||||
// for (let i=0; i< w.length; i++){
|
||||
// console.log(w[i]);
|
||||
// }
|
||||
const buff= await witnessCalculator.calculateWTNSBin(input,0);
|
||||
writeFile(process.argv[4], buff, function(err) {
|
||||
if (err) throw err;
|
||||
});
|
||||
});
|
||||
}
|
||||
306
circom/circuit_js/witness_calculator.js
Normal file
306
circom/circuit_js/witness_calculator.js
Normal file
@@ -0,0 +1,306 @@
|
||||
module.exports = async function builder(code, options) {
|
||||
|
||||
options = options || {};
|
||||
|
||||
let wasmModule;
|
||||
try {
|
||||
wasmModule = await WebAssembly.compile(code);
|
||||
} catch (err) {
|
||||
console.log(err);
|
||||
console.log("\nTry to run circom --c in order to generate c++ code instead\n");
|
||||
throw new Error(err);
|
||||
}
|
||||
|
||||
let wc;
|
||||
|
||||
|
||||
const instance = await WebAssembly.instantiate(wasmModule, {
|
||||
runtime: {
|
||||
exceptionHandler : function(code) {
|
||||
let errStr;
|
||||
if (code == 1) {
|
||||
errStr= "Signal not found. ";
|
||||
} else if (code == 2) {
|
||||
errStr= "Too many signals set. ";
|
||||
} else if (code == 3) {
|
||||
errStr= "Signal already set. ";
|
||||
} else if (code == 4) {
|
||||
errStr= "Assert Failed. ";
|
||||
} else if (code == 5) {
|
||||
errStr= "Not enough memory. ";
|
||||
} else if (code == 6) {
|
||||
errStr= "Input signal array access exceeds the size";
|
||||
} else {
|
||||
errStr= "Unknown error\n";
|
||||
}
|
||||
// get error message from wasm
|
||||
errStr += getMessage();
|
||||
throw new Error(errStr);
|
||||
},
|
||||
showSharedRWMemory: function() {
|
||||
printSharedRWMemory ();
|
||||
}
|
||||
|
||||
}
|
||||
});
|
||||
|
||||
const sanityCheck =
|
||||
options
|
||||
// options &&
|
||||
// (
|
||||
// options.sanityCheck ||
|
||||
// options.logGetSignal ||
|
||||
// options.logSetSignal ||
|
||||
// options.logStartComponent ||
|
||||
// options.logFinishComponent
|
||||
// );
|
||||
|
||||
|
||||
wc = new WitnessCalculator(instance, sanityCheck);
|
||||
return wc;
|
||||
|
||||
function getMessage() {
|
||||
var message = "";
|
||||
var c = instance.exports.getMessageChar();
|
||||
while ( c != 0 ) {
|
||||
message += String.fromCharCode(c);
|
||||
c = instance.exports.getMessageChar();
|
||||
}
|
||||
return message;
|
||||
}
|
||||
|
||||
function printSharedRWMemory () {
|
||||
const shared_rw_memory_size = instance.exports.getFieldNumLen32();
|
||||
const arr = new Uint32Array(shared_rw_memory_size);
|
||||
for (let j=0; j<shared_rw_memory_size; j++) {
|
||||
arr[shared_rw_memory_size-1-j] = instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
console.log(fromArray32(arr));
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class WitnessCalculator {
|
||||
constructor(instance, sanityCheck) {
|
||||
this.instance = instance;
|
||||
|
||||
this.version = this.instance.exports.getVersion();
|
||||
this.n32 = this.instance.exports.getFieldNumLen32();
|
||||
|
||||
this.instance.exports.getRawPrime();
|
||||
const arr = new Uint32Array(this.n32);
|
||||
for (let i=0; i<this.n32; i++) {
|
||||
arr[this.n32-1-i] = this.instance.exports.readSharedRWMemory(i);
|
||||
}
|
||||
this.prime = fromArray32(arr);
|
||||
|
||||
this.witnessSize = this.instance.exports.getWitnessSize();
|
||||
|
||||
this.sanityCheck = sanityCheck;
|
||||
}
|
||||
|
||||
circom_version() {
|
||||
return this.instance.exports.getVersion();
|
||||
}
|
||||
|
||||
async _doCalculateWitness(input, sanityCheck) {
|
||||
//input is assumed to be a map from signals to arrays of bigints
|
||||
this.instance.exports.init((this.sanityCheck || sanityCheck) ? 1 : 0);
|
||||
const keys = Object.keys(input);
|
||||
var input_counter = 0;
|
||||
keys.forEach( (k) => {
|
||||
const h = fnvHash(k);
|
||||
const hMSB = parseInt(h.slice(0,8), 16);
|
||||
const hLSB = parseInt(h.slice(8,16), 16);
|
||||
const fArr = flatArray(input[k]);
|
||||
let signalSize = this.instance.exports.getInputSignalSize(hMSB, hLSB);
|
||||
if (signalSize < 0){
|
||||
throw new Error(`Signal ${k} not found\n`);
|
||||
}
|
||||
if (fArr.length < signalSize) {
|
||||
throw new Error(`Not enough values for input signal ${k}\n`);
|
||||
}
|
||||
if (fArr.length > signalSize) {
|
||||
throw new Error(`Too many values for input signal ${k}\n`);
|
||||
}
|
||||
for (let i=0; i<fArr.length; i++) {
|
||||
const arrFr = toArray32(BigInt(fArr[i])%this.prime,this.n32)
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
this.instance.exports.writeSharedRWMemory(j,arrFr[this.n32-1-j]);
|
||||
}
|
||||
try {
|
||||
this.instance.exports.setInputSignal(hMSB, hLSB,i);
|
||||
input_counter++;
|
||||
} catch (err) {
|
||||
// console.log(`After adding signal ${i} of ${k}`)
|
||||
throw new Error(err);
|
||||
}
|
||||
}
|
||||
|
||||
});
|
||||
if (input_counter < this.instance.exports.getInputSize()) {
|
||||
throw new Error(`Not all inputs have been set. Only ${input_counter} out of ${this.instance.exports.getInputSize()}`);
|
||||
}
|
||||
}
|
||||
|
||||
async calculateWitness(input, sanityCheck) {
|
||||
|
||||
const w = [];
|
||||
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
const arr = new Uint32Array(this.n32);
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
arr[this.n32-1-j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
w.push(fromArray32(arr));
|
||||
}
|
||||
|
||||
return w;
|
||||
}
|
||||
|
||||
|
||||
async calculateBinWitness(input, sanityCheck) {
|
||||
|
||||
const buff32 = new Uint32Array(this.witnessSize*this.n32);
|
||||
const buff = new Uint8Array( buff32.buffer);
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
const pos = i*this.n32;
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
}
|
||||
|
||||
return buff;
|
||||
}
|
||||
|
||||
|
||||
async calculateWTNSBin(input, sanityCheck) {
|
||||
|
||||
const buff32 = new Uint32Array(this.witnessSize*this.n32+this.n32+11);
|
||||
const buff = new Uint8Array( buff32.buffer);
|
||||
await this._doCalculateWitness(input, sanityCheck);
|
||||
|
||||
//"wtns"
|
||||
buff[0] = "w".charCodeAt(0)
|
||||
buff[1] = "t".charCodeAt(0)
|
||||
buff[2] = "n".charCodeAt(0)
|
||||
buff[3] = "s".charCodeAt(0)
|
||||
|
||||
//version 2
|
||||
buff32[1] = 2;
|
||||
|
||||
//number of sections: 2
|
||||
buff32[2] = 2;
|
||||
|
||||
//id section 1
|
||||
buff32[3] = 1;
|
||||
|
||||
const n8 = this.n32*4;
|
||||
//id section 1 length in 64bytes
|
||||
const idSection1length = 8 + n8;
|
||||
const idSection1lengthHex = idSection1length.toString(16);
|
||||
buff32[4] = parseInt(idSection1lengthHex.slice(0,8), 16);
|
||||
buff32[5] = parseInt(idSection1lengthHex.slice(8,16), 16);
|
||||
|
||||
//this.n32
|
||||
buff32[6] = n8;
|
||||
|
||||
//prime number
|
||||
this.instance.exports.getRawPrime();
|
||||
|
||||
var pos = 7;
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
pos += this.n32;
|
||||
|
||||
// witness size
|
||||
buff32[pos] = this.witnessSize;
|
||||
pos++;
|
||||
|
||||
//id section 2
|
||||
buff32[pos] = 2;
|
||||
pos++;
|
||||
|
||||
// section 2 length
|
||||
const idSection2length = n8*this.witnessSize;
|
||||
const idSection2lengthHex = idSection2length.toString(16);
|
||||
buff32[pos] = parseInt(idSection2lengthHex.slice(0,8), 16);
|
||||
buff32[pos+1] = parseInt(idSection2lengthHex.slice(8,16), 16);
|
||||
|
||||
pos += 2;
|
||||
for (let i=0; i<this.witnessSize; i++) {
|
||||
this.instance.exports.getWitness(i);
|
||||
for (let j=0; j<this.n32; j++) {
|
||||
buff32[pos+j] = this.instance.exports.readSharedRWMemory(j);
|
||||
}
|
||||
pos += this.n32;
|
||||
}
|
||||
|
||||
return buff;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
function toArray32(rem,size) {
|
||||
const res = []; //new Uint32Array(size); //has no unshift
|
||||
const radix = BigInt(0x100000000);
|
||||
while (rem) {
|
||||
res.unshift( Number(rem % radix));
|
||||
rem = rem / radix;
|
||||
}
|
||||
if (size) {
|
||||
var i = size - res.length;
|
||||
while (i>0) {
|
||||
res.unshift(0);
|
||||
i--;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function fromArray32(arr) { //returns a BigInt
|
||||
var res = BigInt(0);
|
||||
const radix = BigInt(0x100000000);
|
||||
for (let i = 0; i<arr.length; i++) {
|
||||
res = res*radix + BigInt(arr[i]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function flatArray(a) {
|
||||
var res = [];
|
||||
fillArray(res, a);
|
||||
return res;
|
||||
|
||||
function fillArray(res, a) {
|
||||
if (Array.isArray(a)) {
|
||||
for (let i=0; i<a.length; i++) {
|
||||
fillArray(res, a[i]);
|
||||
}
|
||||
} else {
|
||||
res.push(a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function fnvHash(str) {
|
||||
const uint64_max = BigInt(2) ** BigInt(64);
|
||||
let hash = BigInt("0xCBF29CE484222325");
|
||||
for (var i = 0; i < str.length; i++) {
|
||||
hash ^= BigInt(str[i].charCodeAt());
|
||||
hash *= BigInt(0x100000001B3);
|
||||
hash %= uint64_max;
|
||||
}
|
||||
let shash = hash.toString(16);
|
||||
let n = 16 - shash.length;
|
||||
shash = '0'.repeat(n).concat(shash);
|
||||
return shash;
|
||||
}
|
||||
@@ -17,17 +17,17 @@ async function main(){
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
const r1cs = fs.readFileSync("circuit.r1cs");
|
||||
const ptau = fs.readFileSync("powersOfTau28_hez_final_14.ptau");
|
||||
const r1cs = fs.readFileSync("circom/circuit.r1cs");
|
||||
const ptau = fs.readFileSync("circom/powersOfTau28_hez_final_14.ptau");
|
||||
|
||||
// snarkjs groth16 setup circuit.r1cs powersOfTau28_hez_final_14.ptau circuit_0000.zkey
|
||||
const zkey_0 = {type: "file", fileName: "circuit_0000.zkey"};
|
||||
const zkey_0 = {type: "file", fileName: "circom/circuit_0000.zkey"};
|
||||
await createOverride(zkey_0);
|
||||
console.log("groth16 setup...");
|
||||
await snarkjs.zKey.newZKey(r1cs, ptau, zkey_0);
|
||||
|
||||
// snarkjs zkey contribute circuit_0000.zkey circuit_final.zkey -e="<Notary's entropy>"
|
||||
const zkey_final = {type: "file", fileName: "circuit_final.zkey.notary"};
|
||||
const zkey_final = {type: "file", fileName: "circom/circuit_final.zkey.notary"};
|
||||
await createOverride(zkey_final);
|
||||
console.log("zkey contribute...");
|
||||
await snarkjs.zKey.contribute(zkey_0, zkey_final, "", entropy);
|
||||
@@ -36,7 +36,7 @@ async function main(){
|
||||
console.log("zkey export...");
|
||||
const vKey = await snarkjs.zKey.exportVerificationKey(zkey_final);
|
||||
// copied from snarkjs/cli.js zkeyExportVKey()
|
||||
await bfj.write("verification_key.json", stringifyBigInts(vKey), { space: 1 });
|
||||
await bfj.write("circom/verification_key.json", stringifyBigInts(vKey), { space: 1 });
|
||||
}
|
||||
|
||||
main().then(() => {
|
||||
8
circom/package.json
Normal file
8
circom/package.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"dependencies": {
|
||||
"circom": "^0.5.46",
|
||||
"circom2": "^0.2.5",
|
||||
"circomlibjs": "^0.1.7",
|
||||
"snarkjs": "^0.4.24"
|
||||
}
|
||||
}
|
||||
BIN
circom/powersOfTau28_hez_final_14.ptau
Normal file
BIN
circom/powersOfTau28_hez_final_14.ptau
Normal file
Binary file not shown.
@@ -13,7 +13,7 @@ async function main(){
|
||||
const proof_path = process.argv[4];
|
||||
|
||||
const input = fs.readFileSync(input_path);
|
||||
const wasm = fs.readFileSync(path.join("circuit_js", "circuit.wasm"));
|
||||
const wasm = fs.readFileSync(path.join("circom", "circuit_js", "circuit.wasm"));
|
||||
const zkey_final = fs.readFileSync(proving_key_path);
|
||||
|
||||
const in_json = JSON.parse(input);
|
||||
@@ -10,7 +10,7 @@ async function main(retval){
|
||||
const pub_path = process.argv[2];
|
||||
const proof_path = process.argv[3];
|
||||
|
||||
const vk = JSON.parse(fs.readFileSync("verification_key.json", "utf8"));
|
||||
const vk = JSON.parse(fs.readFileSync("circom/verification_key.json", "utf8"));
|
||||
const pub = JSON.parse(fs.readFileSync(pub_path, "utf8"));
|
||||
const proof = JSON.parse(fs.readFileSync(proof_path, "utf8"));
|
||||
|
||||
168
script.py
168
script.py
@@ -1,168 +0,0 @@
|
||||
import sys
|
||||
import random
|
||||
import os
|
||||
|
||||
def padded_hex(s):
|
||||
h = hex(s)
|
||||
l = len(h)
|
||||
if l % 2 == 0:
|
||||
return h
|
||||
else:
|
||||
return '0x{0:0{1}x}'.format(s,l-1)
|
||||
|
||||
# This script will generate the circom circuit and inputs
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) != 2:
|
||||
print('Expected 1 argument: amount of plaintext to process (in Field elements). Max 16')
|
||||
exit(1)
|
||||
count = int(sys.argv[1])
|
||||
if count > 16:
|
||||
print('Max 16 Field elements allowed. Got ', count)
|
||||
exit(1)
|
||||
|
||||
input = '{\n'
|
||||
input += '"plaintext_hash": "'+str(random.randint(0, 2**254))+'",\n'
|
||||
input += '"label_sum_hash": "'+str(random.randint(0, 2**254))+'",\n'
|
||||
input += '"sum_of_zero_labels": "'+str(random.randint(0, 2**140))+'",\n'
|
||||
input += '"plaintext": [\n'
|
||||
for c in range(0, count):
|
||||
input += ' "'+str(random.randint(0, 2**253))+'"'
|
||||
if c < count-1:
|
||||
input += ',\n'
|
||||
input += "],\n"
|
||||
input += '"delta": [\n'
|
||||
for c in range(0, count):
|
||||
input += ' [\n'
|
||||
for x in range(0, 254):
|
||||
input += ' "'+str(random.randint(0, 2**253))+'"'
|
||||
if x < 253:
|
||||
input += ',\n'
|
||||
input += ' ]\n'
|
||||
if c < count-1:
|
||||
input += ',\n'
|
||||
input += ']\n'
|
||||
input += '}\n'
|
||||
|
||||
with open('input.json', 'w') as f:
|
||||
f.write(input)
|
||||
|
||||
main = 'pragma circom 2.0.0;\n'
|
||||
main += 'include "./poseidon.circom";\n'
|
||||
main += 'include "./utils.circom";\n'
|
||||
main += '\n'
|
||||
|
||||
main += 'template Main() {\n'
|
||||
main += ' signal input plaintext_hash;\n'
|
||||
main += ' signal input label_sum_hash;\n'
|
||||
main += ' signal input plaintext['+str(count)+'];\n'
|
||||
main += ' signal input delta['+str(count)+'][253];\n'
|
||||
main += ' signal input sum_of_zero_labels;\n'
|
||||
main += ' signal sums['+str(count)+'];\n'
|
||||
main += '\n'
|
||||
|
||||
main += ' component hash = Poseidon('+str(count)+');\n'
|
||||
main += ' for (var i = 0; i<'+str(count)+'; i++) {\n'
|
||||
main += ' hash.inputs[i] <== plaintext[i];\n'
|
||||
main += ' }\n'
|
||||
main += '\n'
|
||||
|
||||
main += " // TODO to pass this assert we'd have to\n"
|
||||
main += ' // use actual values instead of random ones, so commenting out for now\n'
|
||||
main += ' plaintext_hash === hash.out;\n'
|
||||
main += '\n'
|
||||
|
||||
# check that Prover's hash is correct. hashing 16 field elements at a time since
|
||||
# idk how to chain hashes with circomlib.
|
||||
# Using prev. digest as the first input to the next hash
|
||||
|
||||
# if is_final is true then count includes the sum_of_labels
|
||||
# def hash(no, start, count, is_final=False):
|
||||
# out = ' component hash_'+str(no)+' = Poseidon('+str(count)+');\n'
|
||||
# if no > 0:
|
||||
# #first element is prev. hash digest
|
||||
# out += ' hash_'+str(no)+'.inputs[0] <== hash_'+str(no-1)+'.out;\n'
|
||||
# else:
|
||||
# if is_final and count == 1:
|
||||
# out += ' hash_'+str(no)+'.inputs[0] <== sum_of_labels;\n'
|
||||
# else:
|
||||
# out += ' hash_'+str(no)+'.inputs[0] <== plaintext['+str(start)+'];\n'
|
||||
# for x in range(1, count-1):
|
||||
# out += ' hash_'+str(no)+'.inputs['+str(x)+'] <== plaintext['+str(start+x)+'];\n'
|
||||
# if is_final:
|
||||
# # sum of labels if the last input
|
||||
# out += ' hash_'+str(no)+'.inputs['+str(count-1)+'] <== sum_of_labels;\n'
|
||||
# else:
|
||||
# out += ' hash_'+str(no)+'.inputs['+str(count-1)+'] <== plaintext['+str(start+count-1)+'];\n'
|
||||
# out += '\n'
|
||||
# return out
|
||||
|
||||
# def hash_str():
|
||||
# out = ''
|
||||
# if count+1 <= 16:
|
||||
# out += hash(0, 0, count+1, True)
|
||||
# out += ' prover_hash <== hash_0.out;\n'
|
||||
# return out
|
||||
# else:
|
||||
# out += hash(0, 0, 16, False)
|
||||
# if count+1 <= 32:
|
||||
# out += hash(1, 16, count+1-16, True)
|
||||
# out += ' prover_hash <== hash_1.out;\n'
|
||||
# return out
|
||||
# else:
|
||||
# out += hash(1, 16, 16, False)
|
||||
# if count+1 <= 48:
|
||||
# out += hash(2, 16, count+1-32, True)
|
||||
# out += ' prover_hash <== hash_2.out;\n'
|
||||
# return out
|
||||
# else:
|
||||
# out += hash(2, 16, 16, False)
|
||||
# if count+1 <= 64:
|
||||
# out += hash(3, 16, count+1-48, True)
|
||||
# out += ' prover_hash <== hash_3.out;\n'
|
||||
# return out
|
||||
# else:
|
||||
# out += hash(3, 16, 16, False)
|
||||
# if count+1 <= 80:
|
||||
# out += hash(4, 16, count+1-64, True)
|
||||
# out += ' prover_hash <== hash_3.out;\n'
|
||||
# return out
|
||||
# else:
|
||||
# out += hash(4, 16, 16, False)
|
||||
|
||||
|
||||
# main += '\n'
|
||||
# main += hash_str()
|
||||
# main += '\n'
|
||||
|
||||
main += ' component ip['+str(count)+'];\n'
|
||||
main += ' for (var i = 0; i<'+str(count)+'; i++) {\n'
|
||||
main += ' ip[i] = InnerProd();\n'
|
||||
main += ' ip[i].plaintext <== plaintext[i];\n'
|
||||
main += ' for (var j=0; j<253; j++) {\n'
|
||||
main += ' ip[i].deltas[j] <== delta[i][j];\n'
|
||||
main += ' }\n'
|
||||
main += ' sums[i] <== ip[i].out;\n'
|
||||
main += ' }\n'
|
||||
main += '\n'
|
||||
|
||||
main += ' signal sum_of_deltas <== '
|
||||
for c in range(0, count):
|
||||
main += 'sums['+str(c)+']'
|
||||
if c < count-1:
|
||||
main += ' + '
|
||||
else:
|
||||
main += ';\n'
|
||||
|
||||
main += " // TODO to pass this assert we'd have to\n"
|
||||
main += ' // use actual values instead of random ones, so commenting out for now\n'
|
||||
main += ' component ls_hash = Poseidon(1);\n'
|
||||
main += ' ls_hash.inputs[0] <== sum_of_zero_labels + sum_of_deltas;\n'
|
||||
main += ' label_sum_hash === ls_hash.out;\n'
|
||||
main += '}\n'
|
||||
|
||||
|
||||
main += 'component main {public [plaintext_hash, label_sum_hash, delta, sum_of_zero_labels]} = Main();'
|
||||
with open('test.circom', 'w') as f:
|
||||
f.write(main)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,111 +1,53 @@
|
||||
pub mod circuit;
|
||||
pub mod onetimesetup;
|
||||
pub mod poseidon_spec;
|
||||
pub mod poseidon;
|
||||
pub mod prover;
|
||||
pub mod utils;
|
||||
pub mod verifier;
|
||||
|
||||
use crate::halo2_backend::onetimesetup::OneTimeSetup;
|
||||
use crate::halo2_backend::utils::u8vec_to_boolvec;
|
||||
use rand::{thread_rng, Rng};
|
||||
#[test]
|
||||
fn e2e_test() {
|
||||
let mut rng = thread_rng();
|
||||
/// The amount of useful bits, see [crate::prover::Prove::useful_bits].
|
||||
/// This value is hard-coded into the circuit regardless of whether we use pasta
|
||||
/// curves (field size 255) or the bn254 curve (field size 254).
|
||||
const USEFUL_BITS: usize = 253;
|
||||
|
||||
let ots = OneTimeSetup::new();
|
||||
/// The size of the chunk, see [crate::prover::Prove::chunk_size].
|
||||
/// We use 14 field elements of 253 bits and 128 bits of the 15th field
|
||||
/// element: 14 * 253 + 128 == 3670 bits total. The low 125 bits
|
||||
/// of the last field element will be used for the salt.
|
||||
const CHUNK_SIZE: usize = 3670;
|
||||
|
||||
// The Prover should have generated the proving key (before the labelsum
|
||||
// protocol starts) like this:
|
||||
ots.setup().unwrap();
|
||||
let proving_key = ots.get_proving_key();
|
||||
/// The elliptic curve on which the Poseidon hash will be computed.
|
||||
pub enum Curve {
|
||||
PASTA,
|
||||
BN254,
|
||||
}
|
||||
|
||||
// generate random plaintext of random size up to 2000 bytes
|
||||
let plaintext: Vec<u8> = core::iter::repeat_with(|| rng.gen::<u8>())
|
||||
.take(thread_rng().gen_range(0..300))
|
||||
.collect();
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::onetimesetup::OneTimeSetup;
|
||||
use super::prover::Prover;
|
||||
use super::verifier::Verifier;
|
||||
use super::*;
|
||||
use crate::tests::fixtures::e2e_test;
|
||||
|
||||
// Normally, the Prover is expected to obtain her binary labels by
|
||||
// evaluating the garbled circuit.
|
||||
// To keep this test simple, we don't evaluate the gc, but we generate
|
||||
// all labels of the Verifier and give the Prover her active labels.
|
||||
let bit_size = plaintext.len() * 8;
|
||||
let mut all_binary_labels: Vec<[u128; 2]> = Vec::with_capacity(bit_size);
|
||||
let mut delta: u128 = rng.gen();
|
||||
// set the last bit
|
||||
delta |= 1;
|
||||
for _ in 0..bit_size {
|
||||
let label_zero: u128 = rng.gen();
|
||||
all_binary_labels.push([label_zero, label_zero ^ delta]);
|
||||
#[test]
|
||||
/// Tests the whole authdecode protocol end-to-end
|
||||
fn halo2_e2e_test() {
|
||||
let mut prover_ots = OneTimeSetup::new();
|
||||
let mut verifier_ots = OneTimeSetup::new();
|
||||
|
||||
// The Prover should have generated the proving key (before the labelsum
|
||||
// protocol starts) like this:
|
||||
prover_ots.setup();
|
||||
let proving_key = prover_ots.get_proving_key();
|
||||
|
||||
// The Verifier should have generated the verifying key (before the labelsum
|
||||
// protocol starts) like this:
|
||||
verifier_ots.setup();
|
||||
let verification_key = verifier_ots.get_verification_key();
|
||||
|
||||
let prover = Box::new(Prover::new(proving_key));
|
||||
let verifier = Box::new(Verifier::new(verification_key, Curve::PASTA));
|
||||
e2e_test(prover, verifier);
|
||||
}
|
||||
let prover_labels = choose(&all_binary_labels, &u8vec_to_boolvec(&plaintext));
|
||||
|
||||
let verifier = LabelsumVerifier::new(
|
||||
all_binary_labels.clone(),
|
||||
Box::new(verifiernode::VerifierNode {}),
|
||||
);
|
||||
|
||||
let verifier = verifier.setup().unwrap();
|
||||
|
||||
let prover = LabelsumProver::new(
|
||||
proving_key,
|
||||
prime,
|
||||
plaintext,
|
||||
poseidon,
|
||||
Box::new(provernode::ProverNode {}),
|
||||
);
|
||||
|
||||
// Perform setup
|
||||
let prover = prover.setup().unwrap();
|
||||
|
||||
// Commitment to the plaintext is sent to the Notary
|
||||
let (plaintext_hash, prover) = prover.plaintext_commitment().unwrap();
|
||||
|
||||
// Notary sends back encrypted arithm. labels.
|
||||
let (cipheretexts, verifier) = verifier.receive_plaintext_hashes(plaintext_hash);
|
||||
|
||||
// Hash commitment to the label_sum is sent to the Notary
|
||||
let (label_sum_hashes, prover) = prover
|
||||
.labelsum_commitment(cipheretexts, &prover_labels)
|
||||
.unwrap();
|
||||
|
||||
// Notary sends the arithmetic label seed
|
||||
let (seed, verifier) = verifier.receive_labelsum_hashes(label_sum_hashes);
|
||||
|
||||
// At this point the following happens in the `committed GC` protocol:
|
||||
// - the Notary reveals the GC seed
|
||||
// - the User checks that the GC was created from that seed
|
||||
// - the User checks that her active output labels correspond to the
|
||||
// output labels derived from the seed
|
||||
// - we are called with the result of the check and (if successful)
|
||||
// with all the output labels
|
||||
|
||||
let prover = prover
|
||||
.binary_labels_authenticated(true, Some(all_binary_labels))
|
||||
.unwrap();
|
||||
|
||||
// Prover checks the integrity of the arithmetic labels and generates zero_sums and deltas
|
||||
let prover = prover.authenticate_arithmetic_labels(seed).unwrap();
|
||||
|
||||
// Prover generates the proof
|
||||
let (proofs, prover) = prover.create_zk_proof().unwrap();
|
||||
|
||||
// Notary verifies the proof
|
||||
let verifier = verifier.verify_many(proofs).unwrap();
|
||||
assert_eq!(
|
||||
type_of(&verifier),
|
||||
"labelsum::verifier::LabelsumVerifier<labelsum::verifier::VerificationSuccessfull>"
|
||||
);
|
||||
}
|
||||
|
||||
/// Unzips a slice of pairs, returning items corresponding to choice
|
||||
fn choose<T: Clone>(items: &[[T; 2]], choice: &[bool]) -> Vec<T> {
|
||||
assert!(items.len() == choice.len(), "arrays are different length");
|
||||
items
|
||||
.iter()
|
||||
.zip(choice)
|
||||
.map(|(items, choice)| items[*choice as usize].clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn type_of<T>(_: &T) -> &'static str {
|
||||
std::any::type_name::<T>()
|
||||
}
|
||||
|
||||
@@ -1,26 +1,21 @@
|
||||
use super::circuit::{LabelsumCircuit, K, USEFUL_ROWS};
|
||||
use super::circuit::{LabelsumCircuit, CELLS_PER_ROW, K, USEFUL_ROWS};
|
||||
use super::prover::PK;
|
||||
use super::verifier::VK;
|
||||
use halo2_proofs::plonk;
|
||||
use halo2_proofs::plonk::ProvingKey;
|
||||
use halo2_proofs::plonk::VerifyingKey;
|
||||
use halo2_proofs::poly::commitment::Params;
|
||||
use pasta_curves::pallas::Base as F;
|
||||
use pasta_curves::EqAffine;
|
||||
|
||||
pub struct OneTimeSetup {
|
||||
proving_key: Option<ProvingKey<EqAffine>>,
|
||||
verification_key: Option<VerifyingKey<EqAffine>>,
|
||||
proving_key: Option<PK>,
|
||||
verification_key: Option<VK>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
FileDoesNotExist,
|
||||
SnarkjsError,
|
||||
}
|
||||
|
||||
// OneTimeSetup should be run when Notary starts. It generates a proving and
|
||||
// a verification keys.
|
||||
// Note that currently halo2 does not support serializing the proving/verification
|
||||
// keys. That's why we can't use cached keys but need to re-generate them every time.
|
||||
/// OneTimeSetup generates the proving key and the verification key. It can be
|
||||
/// ahead of time before the actual zk proving/verification takes place.
|
||||
///
|
||||
/// Note that as of Oct 2022 halo2 does not support serializing the proving/verification
|
||||
/// keys. That's why we can't use cached keys but need to call this one-time setup every
|
||||
/// time when we instantiate the halo2 prover/verifier.
|
||||
impl OneTimeSetup {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
@@ -29,27 +24,32 @@ impl OneTimeSetup {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn setup(&self) -> Result<(), Error> {
|
||||
pub fn setup(&mut self) {
|
||||
let params: Params<EqAffine> = Params::new(K);
|
||||
// we need an instance of the circuit, the exact inputs don't matter
|
||||
let dummy1 = [F::from(0); 15];
|
||||
let dummy2: [Vec<F>; USEFUL_ROWS] = (0..USEFUL_ROWS)
|
||||
.map(|_| vec![F::from(0)])
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
let circuit = LabelsumCircuit::new(dummy1, dummy2);
|
||||
let circuit = LabelsumCircuit::new(
|
||||
Default::default(),
|
||||
Default::default(),
|
||||
[[Default::default(); CELLS_PER_ROW]; USEFUL_ROWS],
|
||||
);
|
||||
|
||||
// safe to unwrap, we are inputting the same params and circuit on every
|
||||
// invocation
|
||||
let vk = plonk::keygen_vk(¶ms, &circuit).unwrap();
|
||||
let pk = plonk::keygen_pk(¶ms, vk.clone(), &circuit).unwrap();
|
||||
|
||||
self.proving_key = Some(pk);
|
||||
self.verification_key = Some(vk);
|
||||
|
||||
Ok(())
|
||||
self.proving_key = Some(PK {
|
||||
key: pk,
|
||||
params: params.clone(),
|
||||
});
|
||||
self.verification_key = Some(VK { key: vk, params });
|
||||
}
|
||||
|
||||
pub fn get_proving_key(&self) -> ProvingKey<EqAffine> {
|
||||
self.proving_key.unwrap()
|
||||
pub fn get_proving_key(&self) -> PK {
|
||||
self.proving_key.as_ref().unwrap().clone()
|
||||
}
|
||||
|
||||
pub fn get_verification_key(&self) -> VK {
|
||||
self.verification_key.as_ref().unwrap().clone()
|
||||
}
|
||||
}
|
||||
|
||||
130
src/halo2_backend/poseidon.rs
Normal file
130
src/halo2_backend/poseidon.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
use group::ff::Field;
|
||||
use halo2_gadgets::poseidon::primitives::Spec;
|
||||
use halo2_gadgets::poseidon::primitives::{self as poseidon, ConstantLength};
|
||||
use halo2_gadgets::poseidon::Pow5Chip;
|
||||
use halo2_gadgets::poseidon::Pow5Config;
|
||||
use halo2_proofs::plonk::ConstraintSystem;
|
||||
use pasta_curves::pallas::Base as F;
|
||||
use pasta_curves::Fp;
|
||||
|
||||
/// Spec for rate 15 Poseidon which halo2 uses both inside
|
||||
/// the zk circuit and in the clear.
|
||||
///
|
||||
/// Compare it to the spec which zcash uses:
|
||||
/// [halo2_gadgets::poseidon::primitives::P128Pow5T3]
|
||||
#[derive(Debug)]
|
||||
pub struct Spec15;
|
||||
|
||||
impl Spec<Fp, 16, 15> for Spec15 {
|
||||
fn full_rounds() -> usize {
|
||||
8
|
||||
}
|
||||
|
||||
fn partial_rounds() -> usize {
|
||||
56
|
||||
}
|
||||
|
||||
fn sbox(val: Fp) -> Fp {
|
||||
val.pow_vartime(&[5])
|
||||
}
|
||||
|
||||
/// TODO: waiting on a definitive answer if returning 0 here is safe
|
||||
/// https://github.com/zcash/halo2/issues/674
|
||||
fn secure_mds() -> usize {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// Spec for rate 1 Poseidon which halo2 uses both inside
|
||||
/// the zk circuit and in the clear.
|
||||
///
|
||||
/// Compare it to the spec which zcash uses:
|
||||
/// [halo2_gadgets::poseidon::primitives::P128Pow5T3]
|
||||
#[derive(Debug)]
|
||||
pub struct Spec1;
|
||||
|
||||
impl Spec<Fp, 2, 1> for Spec1 {
|
||||
fn full_rounds() -> usize {
|
||||
8
|
||||
}
|
||||
|
||||
fn partial_rounds() -> usize {
|
||||
56
|
||||
}
|
||||
|
||||
fn sbox(val: Fp) -> Fp {
|
||||
val.pow_vartime(&[5])
|
||||
}
|
||||
|
||||
fn secure_mds() -> usize {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// Hashes inputs with rate 15 Poseidon and returns the digest
|
||||
///
|
||||
/// Patterned after [halo2_gadgets::poseidon::pow5]
|
||||
/// (see in that file tests::poseidon_hash())
|
||||
pub fn poseidon_15(field_elements: &[F; 15]) -> F {
|
||||
poseidon::Hash::<F, Spec15, ConstantLength<15>, 16, 15>::init().hash(*field_elements)
|
||||
}
|
||||
|
||||
/// Hashes inputs with rate 1 Poseidon and returns the digest
|
||||
///
|
||||
/// Patterned after [halo2_gadgets::poseidon::pow5]
|
||||
/// (see in that file tests::poseidon_hash())
|
||||
pub fn poseidon_1(field_elements: &[F; 1]) -> F {
|
||||
poseidon::Hash::<F, Spec1, ConstantLength<1>, 2, 1>::init().hash(*field_elements)
|
||||
}
|
||||
|
||||
/// Configures the in-circuit Poseidon for rate 15 and returns the config
|
||||
///
|
||||
/// Patterned after [halo2_gadgets::poseidon::pow5]
|
||||
/// (see in that file tests::impl Circuit for HashCircuit::configure())
|
||||
pub fn configure_poseidon_rate_15<S: Spec<F, 16, 15>>(
|
||||
rate: usize,
|
||||
meta: &mut ConstraintSystem<F>,
|
||||
) -> Pow5Config<Fp, 16, 15> {
|
||||
let width = rate + 1;
|
||||
let state = (0..width).map(|_| meta.advice_column()).collect::<Vec<_>>();
|
||||
let partial_sbox = meta.advice_column();
|
||||
|
||||
let rc_a = (0..width).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
let rc_b = (0..width).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
|
||||
meta.enable_constant(rc_b[0]);
|
||||
|
||||
Pow5Chip::configure::<S>(
|
||||
meta,
|
||||
state.try_into().unwrap(),
|
||||
partial_sbox,
|
||||
rc_a.try_into().unwrap(),
|
||||
rc_b.try_into().unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Configures the in-circuit Poseidon for rate 1 and returns the config
|
||||
///
|
||||
/// Patterned after [halo2_gadgets::poseidon::pow5]
|
||||
/// (see in that file tests::impl Circuit for HashCircuit::configure())
|
||||
pub fn configure_poseidon_rate_1<S: Spec<F, 2, 1>>(
|
||||
rate: usize,
|
||||
meta: &mut ConstraintSystem<F>,
|
||||
) -> Pow5Config<Fp, 2, 1> {
|
||||
let width = rate + 1;
|
||||
let state = (0..width).map(|_| meta.advice_column()).collect::<Vec<_>>();
|
||||
let partial_sbox = meta.advice_column();
|
||||
|
||||
let rc_a = (0..width).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
let rc_b = (0..width).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
|
||||
meta.enable_constant(rc_b[0]);
|
||||
|
||||
Pow5Chip::configure::<S>(
|
||||
meta,
|
||||
state.try_into().unwrap(),
|
||||
partial_sbox,
|
||||
rc_a.try_into().unwrap(),
|
||||
rc_b.try_into().unwrap(),
|
||||
)
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
use group::ff::Field;
|
||||
/// Specs which halo2 uses to compute a Poseidon hash both inside the zk
|
||||
/// circuit and in the clear.
|
||||
///
|
||||
///
|
||||
use halo2_gadgets::poseidon::primitives::Spec;
|
||||
use pasta_curves::Fp;
|
||||
|
||||
// Poseidon spec for 15-rate Poseidon
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Spec15;
|
||||
|
||||
impl Spec<Fp, 16, 15> for Spec15 {
|
||||
fn full_rounds() -> usize {
|
||||
8
|
||||
}
|
||||
|
||||
fn partial_rounds() -> usize {
|
||||
56
|
||||
}
|
||||
|
||||
fn sbox(val: Fp) -> Fp {
|
||||
val.pow_vartime(&[5])
|
||||
}
|
||||
|
||||
fn secure_mds() -> usize {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
// Poseidon spec for 1-rate Poseidon
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Spec1;
|
||||
|
||||
impl Spec<Fp, 2, 1> for Spec1 {
|
||||
fn full_rounds() -> usize {
|
||||
8
|
||||
}
|
||||
|
||||
fn partial_rounds() -> usize {
|
||||
56
|
||||
}
|
||||
|
||||
fn sbox(val: Fp) -> Fp {
|
||||
val.pow_vartime(&[5])
|
||||
}
|
||||
|
||||
fn secure_mds() -> usize {
|
||||
0
|
||||
}
|
||||
}
|
||||
@@ -1,54 +1,47 @@
|
||||
use super::circuit::{LabelsumCircuit, CELLS_PER_ROW, FULL_FIELD_ELEMENTS, K, USEFUL_ROWS};
|
||||
use super::poseidon_spec::{Spec1, Spec15};
|
||||
use super::utils::boolvec_to_u8vec;
|
||||
use super::utils::{bigint_to_f, f_to_bigint};
|
||||
use crate::prover::{ProofInput, Prove, ProverError, ProvingKeyTrait};
|
||||
use halo2_gadgets::poseidon::{
|
||||
primitives::{self as poseidon, ConstantLength, Spec},
|
||||
Hash, Pow5Chip, Pow5Config,
|
||||
use super::circuit::{
|
||||
LabelsumCircuit, CELLS_PER_ROW, SALT_SIZE, TOTAL_FIELD_ELEMENTS, USEFUL_ROWS,
|
||||
};
|
||||
use halo2_proofs::arithmetic::FieldExt;
|
||||
use super::poseidon::{poseidon_1, poseidon_15};
|
||||
use super::utils::{bigint_to_f, deltas_to_matrices, f_to_bigint};
|
||||
use super::{CHUNK_SIZE, USEFUL_BITS};
|
||||
use crate::prover::{ProofInput, Prove, ProverError};
|
||||
use halo2_proofs::plonk;
|
||||
use halo2_proofs::plonk::ProvingKey;
|
||||
use halo2_proofs::plonk::SingleVerifier;
|
||||
use halo2_proofs::poly::commitment::Params;
|
||||
use halo2_proofs::transcript::Blake2bRead;
|
||||
use halo2_proofs::transcript::Blake2bWrite;
|
||||
use halo2_proofs::transcript::Challenge255;
|
||||
use halo2_proofs::transcript::{Blake2bWrite, Challenge255};
|
||||
use instant::Instant;
|
||||
use num::BigUint;
|
||||
use pasta_curves::pallas::Base as F;
|
||||
use pasta_curves::EqAffine;
|
||||
use rand::{thread_rng, Rng};
|
||||
use rand::thread_rng;
|
||||
|
||||
// halo2's native ProvingKey can't be used without params, so we wrap
|
||||
// them in one struct.
|
||||
/// halo2's native ProvingKey can't be used without params, so we wrap
|
||||
/// them in one struct.
|
||||
#[derive(Clone)]
|
||||
pub struct PK {
|
||||
key: ProvingKey<EqAffine>,
|
||||
params: Params<EqAffine>,
|
||||
pub key: ProvingKey<EqAffine>,
|
||||
pub params: Params<EqAffine>,
|
||||
}
|
||||
impl ProvingKeyTrait for PK {}
|
||||
|
||||
pub struct Prover {}
|
||||
pub struct Prover {
|
||||
proving_key: PK,
|
||||
}
|
||||
|
||||
impl Prove for Prover {
|
||||
fn prove(&self, input: ProofInput, proving_key: PK) -> Result<Vec<u8>, ProverError> {
|
||||
// convert each delta into a field element type
|
||||
let deltas: Vec<F> = input.deltas.iter().map(|d| bigint_to_f(d)).collect();
|
||||
fn prove(&self, input: ProofInput) -> Result<Vec<u8>, ProverError> {
|
||||
if input.deltas.len() != self.chunk_size() || input.plaintext.len() != TOTAL_FIELD_ELEMENTS
|
||||
{
|
||||
// this can only be caused by an error in
|
||||
// `crate::prover::LabelsumProver` logic
|
||||
return Err(ProverError::InternalError);
|
||||
}
|
||||
|
||||
// to make handling simpler, we pad each set of 253 deltas
|
||||
// with 3 zero deltas on the left.
|
||||
let deltas: Vec<F> = deltas
|
||||
.chunks(self.useful_bits())
|
||||
.map(|c| {
|
||||
let mut v = vec![F::from(0); 3];
|
||||
v.extend(c.to_vec());
|
||||
v
|
||||
})
|
||||
.flatten()
|
||||
.collect();
|
||||
// convert into matrices
|
||||
let (deltas_as_rows, deltas_as_columns) =
|
||||
deltas_to_matrices(&input.deltas, self.useful_bits());
|
||||
|
||||
// convert plaintext into field element type
|
||||
let plaintext: [F; 15] = input
|
||||
// convert plaintext into F type
|
||||
let plaintext: [F; TOTAL_FIELD_ELEMENTS] = input
|
||||
.plaintext
|
||||
.iter()
|
||||
.map(|bigint| bigint_to_f(bigint))
|
||||
@@ -56,100 +49,53 @@ impl Prove for Prover {
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
// number of chunks should be equal to USEFUL_ROWS
|
||||
let all_deltas: [Vec<F>; USEFUL_ROWS] = deltas
|
||||
.chunks(CELLS_PER_ROW)
|
||||
.map(|c| c.to_vec())
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
// transpose to make CELLS_PER_ROW instance columns
|
||||
let input_deltas: [Vec<F>; CELLS_PER_ROW] = (0..CELLS_PER_ROW)
|
||||
.map(|i| {
|
||||
all_deltas
|
||||
.iter()
|
||||
.map(|inner| inner[i].clone())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let circuit = LabelsumCircuit::new(plaintext, all_deltas.clone());
|
||||
|
||||
use instant::Instant;
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
let params = proving_key;
|
||||
|
||||
let params: Params<EqAffine> = Params::new(K);
|
||||
|
||||
let vk = plonk::keygen_vk(¶ms, &circuit).unwrap();
|
||||
let pk = plonk::keygen_pk(¶ms, vk.clone(), &circuit).unwrap();
|
||||
|
||||
println!("ProvingKey built [{:?}]", now.elapsed());
|
||||
|
||||
let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]);
|
||||
|
||||
let mut all_inputs: Vec<&[F]> = Vec::new();
|
||||
for i in 0..input_deltas.len() {
|
||||
let d = input_deltas[i].as_slice();
|
||||
all_inputs.push(d);
|
||||
}
|
||||
// arrange into the format which halo2 expects
|
||||
let mut all_inputs: Vec<&[F]> = deltas_as_columns.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
// add another column with public inputs
|
||||
let tmp = &[
|
||||
bigint_to_f(&input.plaintext_hash),
|
||||
bigint_to_f(&input.label_sum_hash),
|
||||
bigint_to_f(&input.sum_of_zero_labels),
|
||||
];
|
||||
all_inputs.push(tmp);
|
||||
println!("{:?} all inputs len", all_inputs.len());
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
// prepare the proving system and generate the proof:
|
||||
|
||||
let circuit =
|
||||
LabelsumCircuit::new(plaintext, bigint_to_f(&input.salt), deltas_as_rows.into());
|
||||
|
||||
let params = &self.proving_key.params;
|
||||
let pk = &self.proving_key.key;
|
||||
|
||||
let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]);
|
||||
|
||||
let mut rng = thread_rng();
|
||||
|
||||
plonk::create_proof(
|
||||
¶ms,
|
||||
&pk,
|
||||
params,
|
||||
pk,
|
||||
&[circuit],
|
||||
&[all_inputs.as_slice()],
|
||||
&mut rng,
|
||||
&mut transcript,
|
||||
)
|
||||
.unwrap();
|
||||
//console::log_1(&format!("Proof created {:?}", now.elapsed()).into());
|
||||
|
||||
println!("Proof created [{:?}]", now.elapsed());
|
||||
|
||||
let proof = transcript.finalize();
|
||||
|
||||
let now = Instant::now();
|
||||
let strategy = SingleVerifier::new(¶ms);
|
||||
let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]);
|
||||
|
||||
plonk::verify_proof(
|
||||
¶ms,
|
||||
&vk,
|
||||
strategy,
|
||||
&[all_inputs.as_slice()],
|
||||
&mut transcript,
|
||||
)
|
||||
.unwrap();
|
||||
//console::log_1(&format!("Proof verified {:?}", now.elapsed()).into());
|
||||
//console::log_1(&format!("Proof created {:?}", now.elapsed()).into());
|
||||
|
||||
println!("Proof verified [{:?}]", now.elapsed());
|
||||
println!("Proof size [{} kB]", proof.len() as f64 / 1024.0);
|
||||
Ok(proof)
|
||||
}
|
||||
|
||||
fn useful_bits(&self) -> usize {
|
||||
253
|
||||
USEFUL_BITS
|
||||
}
|
||||
|
||||
fn poseidon_rate(&self) -> usize {
|
||||
15
|
||||
TOTAL_FIELD_ELEMENTS
|
||||
}
|
||||
|
||||
fn permutation_count(&self) -> usize {
|
||||
@@ -157,19 +103,16 @@ impl Prove for Prover {
|
||||
}
|
||||
|
||||
fn salt_size(&self) -> usize {
|
||||
125
|
||||
SALT_SIZE
|
||||
}
|
||||
|
||||
// we have 14 field elements of 253 bits and 128 bits of the 15th field
|
||||
// element, i.e. 14*253+128==3670 bits. The least 125 bits of the last
|
||||
// field element will be used for the salt.
|
||||
fn chunk_size(&self) -> usize {
|
||||
3670
|
||||
CHUNK_SIZE
|
||||
}
|
||||
|
||||
// Hashes inputs with Poseidon and returns the digest.
|
||||
/// Hashes `inputs` with Poseidon and returns the digest as `BigUint`.
|
||||
fn hash(&self, inputs: &Vec<BigUint>) -> Result<BigUint, ProverError> {
|
||||
let d = match inputs.len() {
|
||||
let digest = match inputs.len() {
|
||||
15 => {
|
||||
// hash with rate-15 Poseidon
|
||||
let fes: [F; 15] = inputs
|
||||
@@ -178,7 +121,7 @@ impl Prove for Prover {
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
poseidon::Hash::<F, Spec15, ConstantLength<15>, 16, 15>::init().hash(fes)
|
||||
poseidon_15(&fes)
|
||||
}
|
||||
1 => {
|
||||
// hash with rate-1 Poseidon
|
||||
@@ -188,179 +131,16 @@ impl Prove for Prover {
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
poseidon::Hash::<F, Spec1, ConstantLength<1>, 2, 1>::init().hash(fes)
|
||||
poseidon_1(&fes)
|
||||
}
|
||||
_ => return Err(ProverError::WrongPoseidonInput),
|
||||
};
|
||||
Ok(f_to_bigint(&d))
|
||||
Ok(f_to_bigint(&digest))
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn maintest() {
|
||||
use super::circuit::{LabelsumCircuit, CELLS_PER_ROW, FULL_FIELD_ELEMENTS, K, USEFUL_ROWS};
|
||||
use super::utils::boolvec_to_u8vec;
|
||||
use halo2_proofs::arithmetic::FieldExt;
|
||||
use halo2_proofs::plonk;
|
||||
use halo2_proofs::plonk::SingleVerifier;
|
||||
use halo2_proofs::poly::commitment::Params;
|
||||
use halo2_proofs::transcript::Blake2bRead;
|
||||
use halo2_proofs::transcript::Blake2bWrite;
|
||||
use halo2_proofs::transcript::Challenge255;
|
||||
use pasta_curves::EqAffine;
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
let mut rng = thread_rng();
|
||||
|
||||
// generate random plaintext to fill all cells. The first 3 bits of each
|
||||
// full field element are not used, so we zero them out
|
||||
const TOTAL_PLAINTEXT_SIZE: usize = CELLS_PER_ROW * USEFUL_ROWS;
|
||||
let mut plaintext_bits: [bool; TOTAL_PLAINTEXT_SIZE] =
|
||||
core::iter::repeat_with(|| rng.gen::<bool>())
|
||||
.take(TOTAL_PLAINTEXT_SIZE)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
for i in 0..FULL_FIELD_ELEMENTS {
|
||||
plaintext_bits[256 * i + 0] = false;
|
||||
plaintext_bits[256 * i + 1] = false;
|
||||
plaintext_bits[256 * i + 2] = false;
|
||||
impl Prover {
|
||||
pub fn new(pk: PK) -> Self {
|
||||
Self { proving_key: pk }
|
||||
}
|
||||
|
||||
// random deltas. The first 3 deltas of each set of 256 are not used, so we
|
||||
// zero them out.
|
||||
let mut deltas: [F; TOTAL_PLAINTEXT_SIZE] =
|
||||
core::iter::repeat_with(|| F::from_u128(rng.gen::<u128>()))
|
||||
.take(TOTAL_PLAINTEXT_SIZE)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
for i in 0..FULL_FIELD_ELEMENTS {
|
||||
deltas[256 * i + 0] = F::from(0);
|
||||
deltas[256 * i + 1] = F::from(0);
|
||||
deltas[256 * i + 2] = F::from(0);
|
||||
}
|
||||
|
||||
let pt_chunks = plaintext_bits.chunks(256);
|
||||
// plaintext has 15 BigUint field elements - this is how the User is
|
||||
// expected to call this halo2 prover
|
||||
let plaintext: [BigUint; FULL_FIELD_ELEMENTS + 1] = pt_chunks
|
||||
.map(|c| {
|
||||
// convert each chunk of bits into a field element
|
||||
BigUint::from_bytes_be(&boolvec_to_u8vec(c))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
// number of chunks should be equal to USEFUL_ROWS
|
||||
let all_deltas: [Vec<F>; USEFUL_ROWS] = deltas
|
||||
.chunks(CELLS_PER_ROW)
|
||||
.map(|c| c.to_vec())
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
// transpose to make CELLS_PER_ROW instance columns
|
||||
let input_deltas: [Vec<F>; CELLS_PER_ROW] = (0..CELLS_PER_ROW)
|
||||
.map(|i| {
|
||||
all_deltas
|
||||
.iter()
|
||||
.map(|inner| inner[i].clone())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let mut hash_input = [F::default(); 15];
|
||||
for i in 0..hash_input.len() {
|
||||
hash_input[i] = bigint_to_f(&plaintext[i]);
|
||||
}
|
||||
let plaintext_digest =
|
||||
poseidon::Hash::<F, Spec15, ConstantLength<15>, 16, 15>::init().hash(hash_input);
|
||||
|
||||
// compute labelsum digest
|
||||
let mut labelsum = F::from(0);
|
||||
for it in plaintext_bits.iter().zip(deltas.clone()) {
|
||||
let (p, d) = it;
|
||||
let dot_product = F::from(*p) * d;
|
||||
labelsum += dot_product;
|
||||
}
|
||||
let labelsum_digest =
|
||||
poseidon::Hash::<F, Spec1, ConstantLength<1>, 2, 1>::init().hash([labelsum]);
|
||||
|
||||
let circuit = LabelsumCircuit::new(plaintext, all_deltas.clone());
|
||||
|
||||
// let prover = MockProver::<pallas::Base>::run(k, &circuit, input_deltas.clone()).unwrap();
|
||||
|
||||
// assert_eq!(prover.verify(), Ok(()));
|
||||
|
||||
use instant::Instant;
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
let params: Params<EqAffine> = Params::new(K);
|
||||
|
||||
let vk = plonk::keygen_vk(¶ms, &circuit).unwrap();
|
||||
let pk = plonk::keygen_pk(¶ms, vk.clone(), &circuit).unwrap();
|
||||
|
||||
//console::log_1(&format!("ProvingKey built {:?}", now.elapsed()).into());
|
||||
println!("ProvingKey built [{:?}]", now.elapsed());
|
||||
|
||||
let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]);
|
||||
|
||||
let mut all_inputs: Vec<&[F]> = Vec::new();
|
||||
for i in 0..input_deltas.len() {
|
||||
let d = input_deltas[i].as_slice();
|
||||
all_inputs.push(d);
|
||||
}
|
||||
|
||||
let tmp = &[plaintext_digest, labelsum_digest];
|
||||
all_inputs.push(tmp);
|
||||
println!("{:?} all inputs len", all_inputs.len());
|
||||
|
||||
let now = Instant::now();
|
||||
plonk::create_proof(
|
||||
¶ms,
|
||||
&pk,
|
||||
&[circuit],
|
||||
&[all_inputs.as_slice()],
|
||||
&mut rng,
|
||||
&mut transcript,
|
||||
)
|
||||
.unwrap();
|
||||
//console::log_1(&format!("Proof created {:?}", now.elapsed()).into());
|
||||
|
||||
println!("Proof created [{:?}]", now.elapsed());
|
||||
|
||||
let proof = transcript.finalize();
|
||||
|
||||
let now = Instant::now();
|
||||
let strategy = SingleVerifier::new(¶ms);
|
||||
let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]);
|
||||
|
||||
plonk::verify_proof(
|
||||
¶ms,
|
||||
&vk,
|
||||
strategy,
|
||||
&[all_inputs.as_slice()],
|
||||
&mut transcript,
|
||||
)
|
||||
.unwrap();
|
||||
//console::log_1(&format!("Proof verified {:?}", now.elapsed()).into());
|
||||
//console::log_1(&format!("Proof created {:?}", now.elapsed()).into());
|
||||
|
||||
println!("Proof verified [{:?}]", now.elapsed());
|
||||
println!("Proof size [{} kB]", proof.len() as f64 / 1024.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test() {
|
||||
use num::FromPrimitive;
|
||||
|
||||
let two = BigUint::from_u8(2).unwrap();
|
||||
let pow_2_64 = two.pow(64);
|
||||
let pow_2_128 = two.pow(128);
|
||||
let pow_2_192 = two.pow(192);
|
||||
}
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
use super::circuit::{CELLS_PER_ROW, USEFUL_ROWS};
|
||||
use crate::utils::{boolvec_to_u8vec, u8vec_to_boolvec};
|
||||
use crate::Delta;
|
||||
use halo2_proofs::arithmetic::FieldExt;
|
||||
use num::{BigUint, FromPrimitive};
|
||||
use pasta_curves::Fp as F;
|
||||
|
||||
// Decomposes a `BigUint` into bits and returns the bits in BE bit order,
|
||||
// left padding them with zeroes to the size of 256.
|
||||
pub fn bigint_to_bits(bigint: BigUint) -> [bool; 256] {
|
||||
/// Decomposes a `BigUint` into bits and returns the bits in MSB-first bit order,
|
||||
/// left padding them with zeroes to the size of 256.
|
||||
pub fn bigint_to_256bits(bigint: BigUint) -> [bool; 256] {
|
||||
let bits = u8vec_to_boolvec(&bigint.to_bytes_be());
|
||||
let mut bits256 = vec![false; 256];
|
||||
bits256[256 - bits.len()..].copy_from_slice(&bits);
|
||||
bits256.try_into().unwrap()
|
||||
}
|
||||
|
||||
/// Converts a `BigUint` into an field element type.
|
||||
/// The assumption is that `bigint` was sanitized earlier and is not larger
|
||||
/// than [crate::verifier::Verify::field_size]
|
||||
pub fn bigint_to_f(bigint: &BigUint) -> F {
|
||||
let le = bigint.to_bytes_le();
|
||||
let mut wide = [0u8; 64];
|
||||
@@ -18,13 +24,14 @@ pub fn bigint_to_f(bigint: &BigUint) -> F {
|
||||
F::from_bytes_wide(&wide)
|
||||
}
|
||||
|
||||
/// Converts `F` into a `BigUint` type
|
||||
pub fn f_to_bigint(f: &F) -> BigUint {
|
||||
let tmp: [u8; 32] = f.try_into().unwrap();
|
||||
BigUint::from_bytes_le(&tmp)
|
||||
}
|
||||
|
||||
// Splits up 256 bits into 4 limbs, shifts each limb left
|
||||
// and returns the shifted limbs as `BigUint`s.
|
||||
/// Splits up 256 bits into 4 limbs, shifts each limb left
|
||||
/// and returns the shifted limbs as `BigUint`s.
|
||||
pub fn bits_to_limbs(bits: [bool; 256]) -> [BigUint; 4] {
|
||||
// break up the field element into 4 64-bit limbs
|
||||
// the limb at index 0 is the high limb
|
||||
@@ -35,8 +42,10 @@ pub fn bits_to_limbs(bits: [bool; 256]) -> [BigUint; 4] {
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
// shift each limb to the left
|
||||
// shift each limb to the left:
|
||||
|
||||
let two = BigUint::from_u8(2).unwrap();
|
||||
// how many bits to shift each limb by
|
||||
let shift_by: [BigUint; 4] = [192, 128, 64, 0]
|
||||
.iter()
|
||||
.map(|s| two.pow(*s))
|
||||
@@ -52,31 +61,82 @@ pub fn bits_to_limbs(bits: [bool; 256]) -> [BigUint; 4] {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn u8vec_to_boolvec(v: &[u8]) -> Vec<bool> {
|
||||
let mut bv = Vec::with_capacity(v.len() * 8);
|
||||
for byte in v.iter() {
|
||||
for i in 0..8 {
|
||||
bv.push(((byte >> (7 - i)) & 1) != 0);
|
||||
}
|
||||
}
|
||||
bv
|
||||
/// Converts a vec of padded deltas into a matrix of rows and a matrix of
|
||||
/// columns and returns them.
|
||||
pub fn deltas_to_matrices(
|
||||
deltas: &Vec<Delta>,
|
||||
useful_bits: usize,
|
||||
) -> (
|
||||
[[F; CELLS_PER_ROW]; USEFUL_ROWS],
|
||||
[[F; USEFUL_ROWS]; CELLS_PER_ROW],
|
||||
) {
|
||||
let deltas = convert_and_pad_deltas(deltas, useful_bits);
|
||||
let deltas_as_rows = deltas_to_matrix_of_rows(&deltas, useful_bits);
|
||||
|
||||
let deltas_as_columns = transpose_rows(&deltas_as_rows);
|
||||
|
||||
(deltas_as_rows, deltas_as_columns)
|
||||
}
|
||||
|
||||
// Convert bits into bytes. The bits will be left-padded with zeroes to the
|
||||
// multiple of 8.
|
||||
#[inline]
|
||||
pub fn boolvec_to_u8vec(bv: &[bool]) -> Vec<u8> {
|
||||
let rem = bv.len() % 8;
|
||||
let first_byte_bitsize = if rem == 0 { 8 } else { rem };
|
||||
let offset = if rem == 0 { 0 } else { 1 };
|
||||
let mut v = vec![0u8; bv.len() / 8 + offset];
|
||||
// implicitely left-pad the first byte with zeroes
|
||||
for (i, b) in bv[0..first_byte_bitsize].iter().enumerate() {
|
||||
v[i / 8] |= (*b as u8) << (first_byte_bitsize - 1 - i);
|
||||
}
|
||||
for (i, b) in bv[first_byte_bitsize..].iter().enumerate() {
|
||||
v[1 + i / 8] |= (*b as u8) << (7 - (i % 8));
|
||||
}
|
||||
v
|
||||
/// To make handling inside the circuit simpler, we pad each chunk (except for
|
||||
/// the last one) of deltas with zero values on the left to the size 256.
|
||||
/// Note that the last chunk (corresponding to the 15th field element) will
|
||||
/// contain only 128 deltas, so we do NOT pad it.
|
||||
///
|
||||
/// Returns padded deltas
|
||||
fn convert_and_pad_deltas(deltas: &Vec<Delta>, useful_bits: usize) -> Vec<F> {
|
||||
// convert deltas into F type
|
||||
let deltas: Vec<F> = deltas.iter().map(|d| bigint_to_f(d)).collect();
|
||||
|
||||
deltas
|
||||
.chunks(useful_bits)
|
||||
.enumerate()
|
||||
.map(|(i, c)| {
|
||||
if i < 14 {
|
||||
let mut v = vec![F::from(0); 256 - c.len()];
|
||||
v.extend(c.to_vec());
|
||||
v
|
||||
} else {
|
||||
c.to_vec()
|
||||
}
|
||||
})
|
||||
.flatten()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Converts a vec of padded deltas into a matrix of rows and returns it.
|
||||
fn deltas_to_matrix_of_rows(
|
||||
deltas: &Vec<F>,
|
||||
useful_bits: usize,
|
||||
) -> ([[F; CELLS_PER_ROW]; USEFUL_ROWS]) {
|
||||
deltas
|
||||
.chunks(CELLS_PER_ROW)
|
||||
.map(|c| c.try_into().unwrap())
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Transposes a matrix of rows.
|
||||
fn transpose_rows(matrix: &[[F; CELLS_PER_ROW]; USEFUL_ROWS]) -> [[F; USEFUL_ROWS]; CELLS_PER_ROW] {
|
||||
(0..CELLS_PER_ROW)
|
||||
.map(|i| {
|
||||
matrix
|
||||
.iter()
|
||||
.map(|inner| inner[i].clone().try_into().unwrap())
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Converts a vec of deltas into a matrix of columns and returns it.
|
||||
fn deltas_to_matrix_of_columns(
|
||||
deltas: &Vec<F>,
|
||||
useful_bits: usize,
|
||||
) -> [[F; USEFUL_ROWS]; CELLS_PER_ROW] {
|
||||
transpose_rows(&deltas_to_matrix_of_rows(deltas, useful_bits))
|
||||
}
|
||||
|
||||
@@ -1,63 +1,89 @@
|
||||
use crate::verifier::{VerifierError, Verify};
|
||||
use json::{array, object, stringify, stringify_pretty, JsonValue};
|
||||
use num::{BigUint, FromPrimitive, ToPrimitive, Zero};
|
||||
use std::env::temp_dir;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::process::{Command, Output};
|
||||
use uuid::Uuid;
|
||||
use super::utils::{bigint_to_f, deltas_to_matrices};
|
||||
use super::{Curve, CHUNK_SIZE, USEFUL_BITS};
|
||||
use crate::verifier::{VerificationInput, VerifierError, Verify};
|
||||
use halo2_proofs::plonk;
|
||||
use halo2_proofs::plonk::SingleVerifier;
|
||||
use halo2_proofs::plonk::VerifyingKey;
|
||||
use halo2_proofs::poly::commitment::Params;
|
||||
use halo2_proofs::transcript::Blake2bRead;
|
||||
use halo2_proofs::transcript::Challenge255;
|
||||
use instant::Instant;
|
||||
use pasta_curves::pallas::Base as F;
|
||||
use pasta_curves::EqAffine;
|
||||
|
||||
pub struct Verifier {}
|
||||
/// halo2's native [halo2::VerifyingKey] can't be used without params, so we wrap
|
||||
/// them in one struct.
|
||||
#[derive(Clone)]
|
||||
pub struct VK {
|
||||
pub key: VerifyingKey<EqAffine>,
|
||||
pub params: Params<EqAffine>,
|
||||
}
|
||||
|
||||
pub struct Verifier {
|
||||
verification_key: VK,
|
||||
curve: Curve,
|
||||
}
|
||||
impl Verifier {
|
||||
pub fn new(vk: VK, curve: Curve) -> Self {
|
||||
Self {
|
||||
verification_key: vk,
|
||||
curve,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Verify for Verifier {
|
||||
fn verify(
|
||||
&self,
|
||||
proof: Vec<u8>,
|
||||
deltas: Vec<String>,
|
||||
plaintext_hash: BigUint,
|
||||
labelsum_hash: BigUint,
|
||||
zero_sum: BigUint,
|
||||
) -> Result<bool, VerifierError> {
|
||||
// public.json is a flat array
|
||||
let mut public_json: Vec<String> = Vec::new();
|
||||
public_json.push(plaintext_hash.to_string());
|
||||
public_json.push(labelsum_hash.to_string());
|
||||
public_json.extend::<Vec<String>>(deltas);
|
||||
public_json.push(zero_sum.to_string());
|
||||
let s = stringify(JsonValue::from(public_json.clone()));
|
||||
fn verify(&self, input: VerificationInput) -> Result<bool, VerifierError> {
|
||||
let params = &self.verification_key.params;
|
||||
let vk = &self.verification_key.key;
|
||||
|
||||
// write into temp files and delete the files after verification
|
||||
let mut path1 = temp_dir();
|
||||
let mut path2 = temp_dir();
|
||||
path1.push(format!("public.json.{}", Uuid::new_v4()));
|
||||
path2.push(format!("proof.json.{}", Uuid::new_v4()));
|
||||
fs::write(path1.clone(), s).expect("Unable to write file");
|
||||
fs::write(path2.clone(), proof).expect("Unable to write file");
|
||||
let strategy = SingleVerifier::new(¶ms);
|
||||
let proof = input.proof;
|
||||
let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]);
|
||||
|
||||
let output = Command::new("node")
|
||||
.args([
|
||||
"verify.mjs",
|
||||
path1.to_str().unwrap(),
|
||||
path2.to_str().unwrap(),
|
||||
])
|
||||
.output();
|
||||
fs::remove_file(path1).expect("Unable to remove file");
|
||||
fs::remove_file(path2).expect("Unable to remove file");
|
||||
// convert deltas into a matrix which halo2 expects
|
||||
let (_, deltas_as_columns) = deltas_to_matrices(&input.deltas, self.useful_bits());
|
||||
|
||||
check_output(&output)?;
|
||||
if !output.unwrap().status.success() {
|
||||
return Ok(false);
|
||||
let mut all_inputs: Vec<&[F]> = deltas_as_columns.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
// add another column with public inputs
|
||||
let tmp = &[
|
||||
bigint_to_f(&input.plaintext_hash),
|
||||
bigint_to_f(&input.label_sum_hash),
|
||||
bigint_to_f(&input.sum_of_zero_labels),
|
||||
];
|
||||
all_inputs.push(tmp);
|
||||
|
||||
let now = Instant::now();
|
||||
// perform the actual verification
|
||||
let res = plonk::verify_proof(
|
||||
params,
|
||||
vk,
|
||||
strategy,
|
||||
&[all_inputs.as_slice()],
|
||||
&mut transcript,
|
||||
);
|
||||
println!("Proof verified [{:?}]", now.elapsed());
|
||||
if res.is_err() {
|
||||
return Err(VerifierError::VerificationFailed);
|
||||
} else {
|
||||
Ok(true)
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
|
||||
fn check_output(output: &Result<Output, std::io::Error>) -> Result<(), VerifierError> {
|
||||
if output.is_err() {
|
||||
return Err(VerifierError::SnarkjsError);
|
||||
fn field_size(&self) -> usize {
|
||||
match self.curve {
|
||||
Curve::PASTA => 255,
|
||||
Curve::BN254 => 254,
|
||||
_ => panic!("a new curve was added. Add its field size here."),
|
||||
}
|
||||
}
|
||||
if !output.as_ref().unwrap().status.success() {
|
||||
return Err(VerifierError::SnarkjsError);
|
||||
|
||||
fn useful_bits(&self) -> usize {
|
||||
USEFUL_BITS
|
||||
}
|
||||
|
||||
fn chunk_size(&self) -> usize {
|
||||
CHUNK_SIZE
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
125
src/label.rs
125
src/label.rs
@@ -1,85 +1,86 @@
|
||||
use super::boolvec_to_u8vec;
|
||||
use num::{BigUint, FromPrimitive, ToPrimitive, Zero};
|
||||
use rand::SeedableRng;
|
||||
use rand::{thread_rng, Rng};
|
||||
use super::utils::bits_to_bigint;
|
||||
use num::BigUint;
|
||||
use rand::RngCore;
|
||||
use rand::{thread_rng, Rng, SeedableRng};
|
||||
use rand_chacha::ChaCha20Rng;
|
||||
|
||||
// The PRG for generating arithmetic labels
|
||||
/// The PRG for generating arithmetic labels.
|
||||
type Prg = ChaCha20Rng;
|
||||
/// The seed from which to generate the arithmetic labels.
|
||||
pub type Seed = [u8; 32];
|
||||
// The arithmetic label
|
||||
// The arithmetic label.
|
||||
type Label = BigUint;
|
||||
/// A pair of labels: the first one encodes the value 0, the second one encodes
|
||||
/// the value 1.
|
||||
pub type LabelPair = [Label; 2];
|
||||
|
||||
/// typestates are used to prevent generate() from being called multiple times
|
||||
/// on the same instance of LabelSeed
|
||||
pub trait State {}
|
||||
|
||||
pub struct Generate {
|
||||
seed: Seed,
|
||||
}
|
||||
pub struct Finished {}
|
||||
|
||||
impl State for Generate {}
|
||||
impl State for Finished {}
|
||||
|
||||
pub struct LabelGenerator<S = Generate>
|
||||
where
|
||||
S: State,
|
||||
{
|
||||
state: S,
|
||||
}
|
||||
pub struct LabelGenerator {}
|
||||
|
||||
impl LabelGenerator {
|
||||
pub fn new() -> LabelGenerator<Generate> {
|
||||
LabelGenerator {
|
||||
state: Generate {
|
||||
seed: thread_rng().gen::<Seed>(),
|
||||
},
|
||||
}
|
||||
/// Generates a seed and then generates `count` arithmetic label pairs
|
||||
/// of bitsize `label_size` from that seed. Returns the labels and the seed.
|
||||
pub fn generate(count: usize, label_size: usize) -> (Vec<LabelPair>, Seed) {
|
||||
let seed = thread_rng().gen::<Seed>();
|
||||
let pairs = LabelGenerator::generate_from_seed(count, label_size, seed);
|
||||
(pairs, seed)
|
||||
}
|
||||
|
||||
pub fn new_from_seed(seed: Seed) -> LabelGenerator<Generate> {
|
||||
LabelGenerator {
|
||||
state: Generate { seed },
|
||||
}
|
||||
// Generates `count` arithmetic label pairs of bitsize `label_size` from a
|
||||
// seed and returns the labels.
|
||||
pub fn generate_from_seed(count: usize, label_size: usize, seed: Seed) -> Vec<LabelPair> {
|
||||
let prg = Prg::from_seed(seed);
|
||||
LabelGenerator::generate_from_prg(count, label_size, Box::new(prg))
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelGenerator<Generate> {
|
||||
/// Generates `count` arithmetic label pairs of size `label_size`. Returns
|
||||
/// the generated label pairs and the seed.
|
||||
pub fn generate(
|
||||
self,
|
||||
/// Generates `count` arithmetic label pairs of bitsize `label_size` using a PRG.
|
||||
/// Returns the generated label pairs.
|
||||
fn generate_from_prg(
|
||||
count: usize,
|
||||
label_size: usize,
|
||||
) -> (Vec<LabelPair>, Seed, LabelGenerator<Finished>) {
|
||||
// To keep the handling simple, we want to avoid a negative delta, that's why
|
||||
// W_0 and delta must be (label_size - 1)-bit values and W_1 will be
|
||||
// set to W_0 + delta
|
||||
let mut prg = Prg::from_seed(self.state.seed);
|
||||
|
||||
let label_pairs: Vec<LabelPair> = (0..count)
|
||||
mut prg: Box<dyn RngCore>,
|
||||
) -> Vec<LabelPair> {
|
||||
(0..count)
|
||||
.map(|_| {
|
||||
let zero_label: Vec<bool> = core::iter::repeat_with(|| prg.gen::<bool>())
|
||||
.take(label_size - 1)
|
||||
.collect();
|
||||
let zero_label = BigUint::from_bytes_be(&boolvec_to_u8vec(&zero_label));
|
||||
// To keep the handling simple, we want to avoid a negative delta, that's why
|
||||
// W_0 and delta must be (label_size - 1)-bit values and W_1 will be
|
||||
// set to W_0 + delta
|
||||
let zero_label = bits_to_bigint(
|
||||
&core::iter::repeat_with(|| prg.gen::<bool>())
|
||||
.take(label_size - 1)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let delta: Vec<bool> = core::iter::repeat_with(|| prg.gen::<bool>())
|
||||
.take(label_size - 1)
|
||||
.collect();
|
||||
let delta = BigUint::from_bytes_be(&boolvec_to_u8vec(&delta));
|
||||
let delta = bits_to_bigint(
|
||||
&core::iter::repeat_with(|| prg.gen::<bool>())
|
||||
.take(label_size - 1)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let one_label = zero_label.clone() + delta.clone();
|
||||
[zero_label, one_label]
|
||||
})
|
||||
.collect();
|
||||
|
||||
(
|
||||
label_pairs,
|
||||
self.state.seed,
|
||||
LabelGenerator { state: Finished {} },
|
||||
)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::LabelGenerator;
|
||||
use num::BigUint;
|
||||
use rand::rngs::mock::StepRng;
|
||||
|
||||
#[test]
|
||||
fn test_label_generator() {
|
||||
// PRG which always returns bit 1
|
||||
let prg = StepRng::new(u64::MAX, 0);
|
||||
|
||||
// zero_label and delta should be 511 (bit 1 repeated 9 times), one_label
|
||||
// should be 511+511=1022
|
||||
let result = LabelGenerator::generate_from_prg(10, 10, Box::new(prg));
|
||||
let expected = (0..10)
|
||||
.map(|_| [BigUint::from(511u128), BigUint::from(1022u128)])
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(expected, result);
|
||||
}
|
||||
}
|
||||
|
||||
289
src/lib.rs
289
src/lib.rs
@@ -1,128 +1,72 @@
|
||||
use aes::{Aes128, NewBlockCipher};
|
||||
use cipher::{consts::U16, generic_array::GenericArray, BlockCipher, BlockEncrypt};
|
||||
use num::{BigUint, FromPrimitive, ToPrimitive, Zero};
|
||||
use poseidon::Poseidon;
|
||||
use rand::{thread_rng, Rng};
|
||||
use sha2::{Digest, Sha256};
|
||||
use num::BigUint;
|
||||
|
||||
pub mod halo2_backend;
|
||||
pub mod label;
|
||||
pub mod onetimesetup;
|
||||
pub mod poseidon;
|
||||
pub mod prover;
|
||||
//pub mod provernode;
|
||||
pub mod snarkjs_backend;
|
||||
pub mod utils;
|
||||
pub mod verifier;
|
||||
pub mod verifiernode;
|
||||
|
||||
// bn254 prime 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
|
||||
// in decimal 21888242871839275222246405745257275088548364400416034343698204186575808495617
|
||||
const BN254_PRIME: &str =
|
||||
"21888242871839275222246405745257275088548364400416034343698204186575808495617";
|
||||
|
||||
/// How many field elements our Poseidon hash consumes for one permutation.
|
||||
const POSEIDON_RATE: usize = 16;
|
||||
/// How many permutations our circom circuit supports. One permutation consumes
|
||||
/// POSEIDON_WIDTH field elements.
|
||||
const PERMUTATION_COUNT: usize = 1;
|
||||
/// The bitsize of an arithmetic label. MUST be > 40 to give statistical
|
||||
/// security against the Prover guessing the label. For a 254-bit field,
|
||||
/// the bitsize > 96 would require 2 field elements for the
|
||||
/// salted labelsum instead of 1.
|
||||
const ARITHMETIC_LABEL_SIZE: usize = 96;
|
||||
|
||||
/// The maximum size (in bits) of one chunk of plaintext that we support
|
||||
/// for 254-bit fields. Calculated as 2^{Field_size - 1 - 128 - 96}, where
|
||||
/// 128 is the size of salt and 96 is the size of the arithmetic label.
|
||||
const MAX_CHUNK_SIZE: usize = 1 << 29;
|
||||
/// The maximum supported size (in bits) of one [Chunk] of plaintext.
|
||||
/// Should not exceed 2^{ [prover::Prove::useful_bits] - [prover::Prove::salt_size]
|
||||
/// - [ARITHMETIC_LABEL_SIZE] }.
|
||||
/// 2^20 should suffice for most use cases.
|
||||
const MAX_CHUNK_SIZE: usize = 1 << 20;
|
||||
|
||||
pub fn random_bigint(bitsize: usize) -> BigUint {
|
||||
assert!(bitsize <= 128);
|
||||
let r: [u8; 16] = thread_rng().gen();
|
||||
// take only those bits which we need
|
||||
BigUint::from_bytes_be(&boolvec_to_u8vec(&u8vec_to_boolvec(&r)[0..bitsize]))
|
||||
}
|
||||
/// The maximum supported amount of plaintext [Chunk]s ( which equals to the
|
||||
/// amount of zk proofs). Having too many zk proofs may be a DOS vector
|
||||
/// against the Notary who is the verifier of zk proofs.
|
||||
const MAX_CHUNK_COUNT: usize = 128;
|
||||
|
||||
#[inline]
|
||||
pub fn u8vec_to_boolvec(v: &[u8]) -> Vec<bool> {
|
||||
let mut bv = Vec::with_capacity(v.len() * 8);
|
||||
for byte in v.iter() {
|
||||
for i in 0..8 {
|
||||
bv.push(((byte >> (7 - i)) & 1) != 0);
|
||||
}
|
||||
}
|
||||
bv
|
||||
}
|
||||
/// The decoded output labels of the garbled circuit. In other words, this is
|
||||
/// the plaintext output resulting from the evaluation of a garbled circuit.
|
||||
type Plaintext = Vec<u8>;
|
||||
|
||||
// Convert bits into bytes. The bits will be left-padded with zeroes to the
|
||||
// multiple of 8.
|
||||
#[inline]
|
||||
pub fn boolvec_to_u8vec(bv: &[bool]) -> Vec<u8> {
|
||||
let rem = bv.len() % 8;
|
||||
let first_byte_bitsize = if rem == 0 { 8 } else { rem };
|
||||
let offset = if rem == 0 { 0 } else { 1 };
|
||||
let mut v = vec![0u8; bv.len() / 8 + offset];
|
||||
// implicitely left-pad the first byte with zeroes
|
||||
for (i, b) in bv[0..first_byte_bitsize].iter().enumerate() {
|
||||
v[i / 8] |= (*b as u8) << (first_byte_bitsize - 1 - i);
|
||||
}
|
||||
for (i, b) in bv[first_byte_bitsize..].iter().enumerate() {
|
||||
v[1 + i / 8] |= (*b as u8) << (7 - (i % 8));
|
||||
}
|
||||
v
|
||||
}
|
||||
/// A chunk of [Plaintext]. The amount of vec elements equals
|
||||
/// [Prove::poseidon_rate] * [Prove::permutation_count]. Each vec element
|
||||
/// is an "Elliptic curve field element" into which [Prove::useful_bits] bits
|
||||
/// of [Plaintext] is packed.
|
||||
/// The chunk does NOT contain the [Salt].
|
||||
type Chunk = Vec<BigUint>;
|
||||
|
||||
pub fn sha256(data: &[u8]) -> [u8; 32] {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(data);
|
||||
hasher.finalize().into()
|
||||
}
|
||||
/// Before hashing a [Chunk], it is salted by shifting its last element to the
|
||||
/// left by [Prove::salt_size] and placing the salt into the low bits.
|
||||
/// This same salt is also used to salt the sum of all the labels corresponding
|
||||
/// to the [Chunk].
|
||||
/// Without the salt, a hash of plaintext with low entropy could be brute-forced.
|
||||
type Salt = BigUint;
|
||||
|
||||
/// Encrypts each arithmetic label using a corresponding binary label as a key
|
||||
/// and returns ciphertexts in an order based on binary label's pointer bit (LSB).
|
||||
pub fn encrypt_arithmetic_labels(
|
||||
alabels: &Vec<[BigUint; 2]>,
|
||||
blabels: &Vec<[u128; 2]>,
|
||||
) -> Vec<[Vec<u8>; 2]> {
|
||||
assert!(alabels.len() == blabels.len());
|
||||
/// A Poseidon hash digest of a [Salt]ed [Chunk]. This is an EC field element.
|
||||
type PlaintextHash = BigUint;
|
||||
|
||||
blabels
|
||||
.iter()
|
||||
.zip(alabels)
|
||||
.map(|(bin_pair, arithm_pair)| {
|
||||
let zero_key = Aes128::new_from_slice(&bin_pair[0].to_be_bytes()).unwrap();
|
||||
let one_key = Aes128::new_from_slice(&bin_pair[1].to_be_bytes()).unwrap();
|
||||
let mut label0 = [0u8; 16];
|
||||
let mut label1 = [0u8; 16];
|
||||
let ap0 = arithm_pair[0].to_bytes_be();
|
||||
let ap1 = arithm_pair[1].to_bytes_be();
|
||||
// pad with zeroes on the left
|
||||
label0[16 - ap0.len()..].copy_from_slice(&ap0);
|
||||
label1[16 - ap1.len()..].copy_from_slice(&ap1);
|
||||
let mut label0: GenericArray<u8, U16> = GenericArray::from(label0);
|
||||
let mut label1: GenericArray<u8, U16> = GenericArray::from(label1);
|
||||
zero_key.encrypt_block(&mut label0);
|
||||
one_key.encrypt_block(&mut label1);
|
||||
// ciphertext 0 and ciphertext 1
|
||||
let ct0 = label0.to_vec();
|
||||
let ct1 = label1.to_vec();
|
||||
// place ar. labels based on the point and permute bit of bin. label 0
|
||||
if (bin_pair[0] & 1) == 0 {
|
||||
[ct0, ct1]
|
||||
} else {
|
||||
[ct1, ct0]
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
/// A Poseidon hash digest of a [Salt]ed arithmetic sum of arithmetic labels
|
||||
/// corresponding to the [Chunk]. This is an EC field element.
|
||||
type LabelsumHash = BigUint;
|
||||
|
||||
/// An arithmetic sum of all "zero" arithmetic labels ( those are the labels
|
||||
/// which encode the bit value 0) corresponding to one [Chunk].
|
||||
type ZeroSum = BigUint;
|
||||
|
||||
/// An arithmetic difference between the arithmetic label "one" and the
|
||||
/// arithmetic label "zero".
|
||||
type Delta = BigUint;
|
||||
|
||||
/// A serialized proof proving that a Poseidon hash is the result of hashing a
|
||||
/// salted [Chunk].
|
||||
type Proof = Vec<u8>;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::onetimesetup::OneTimeSetup;
|
||||
|
||||
use super::*;
|
||||
use num::{BigUint, FromPrimitive};
|
||||
use prover::LabelsumProver;
|
||||
use rand::{thread_rng, Rng, RngCore};
|
||||
use rand::{thread_rng, Rng};
|
||||
use verifier::LabelsumVerifier;
|
||||
|
||||
/// Unzips a slice of pairs, returning items corresponding to choice
|
||||
@@ -139,98 +83,85 @@ mod tests {
|
||||
std::any::type_name::<T>()
|
||||
}
|
||||
|
||||
// #[test]
|
||||
// fn e2e_test() {
|
||||
// let prime = String::from(BN254_PRIME).parse::<BigUint>().unwrap();
|
||||
// let mut rng = thread_rng();
|
||||
pub mod fixtures {
|
||||
use super::utils::*;
|
||||
use super::*;
|
||||
use crate::prover::Prove;
|
||||
use crate::verifier::Verify;
|
||||
|
||||
// // OneTimeSetup is a no-op if the setup has been run before
|
||||
// let ots = OneTimeSetup::new();
|
||||
// ots.setup().unwrap();
|
||||
/// Accepts a concrete Prover and Verifier and runs the whole labelsum
|
||||
/// protocol end-to-end.
|
||||
pub fn e2e_test(prover: Box<dyn Prove>, verifier: Box<dyn Verify>) {
|
||||
let mut rng = thread_rng();
|
||||
|
||||
// // Poseidon need to be instantiated once and then passed to all instances
|
||||
// // of Prover
|
||||
// let poseidon = Poseidon::new();
|
||||
// generate random plaintext of random size up to 2000 bytes
|
||||
let plaintext: Vec<u8> = core::iter::repeat_with(|| rng.gen::<u8>())
|
||||
.take(thread_rng().gen_range(0..1000))
|
||||
.collect();
|
||||
|
||||
// // The Prover should have received the proving key (before the labelsum
|
||||
// // protocol starts) like this:
|
||||
// let proving_key = ots.get_proving_key().unwrap();
|
||||
// Normally, the Prover is expected to obtain her binary labels by
|
||||
// evaluating the garbled circuit.
|
||||
// To keep this test simple, we don't evaluate the gc, but we generate
|
||||
// all labels of the Verifier and give the Prover her active labels.
|
||||
let bit_size = plaintext.len() * 8;
|
||||
let mut all_binary_labels: Vec<[u128; 2]> = Vec::with_capacity(bit_size);
|
||||
let mut delta: u128 = rng.gen();
|
||||
// set the last bit
|
||||
delta |= 1;
|
||||
for _ in 0..bit_size {
|
||||
let label_zero: u128 = rng.gen();
|
||||
all_binary_labels.push([label_zero, label_zero ^ delta]);
|
||||
}
|
||||
let prover_labels = choose(&all_binary_labels, &u8vec_to_boolvec(&plaintext));
|
||||
|
||||
// // generate random plaintext of random size up to 2000 bytes
|
||||
// let plaintext: Vec<u8> = core::iter::repeat_with(|| rng.gen::<u8>())
|
||||
// .take(thread_rng().gen_range(0..2000))
|
||||
// .collect();
|
||||
let verifier = LabelsumVerifier::new(all_binary_labels.clone(), verifier);
|
||||
|
||||
// // Normally, the Prover is expected to obtain her binary labels by
|
||||
// // evaluating the garbled circuit.
|
||||
// // To keep this test simple, we don't evaluate the gc, but we generate
|
||||
// // all labels of the Verifier and give the Prover her active labels.
|
||||
// let bit_size = plaintext.len() * 8;
|
||||
// let mut all_binary_labels: Vec<[u128; 2]> = Vec::with_capacity(bit_size);
|
||||
// let mut delta: u128 = rng.gen();
|
||||
// // set the last bit
|
||||
// delta |= 1;
|
||||
// for _ in 0..bit_size {
|
||||
// let label_zero: u128 = rng.gen();
|
||||
// all_binary_labels.push([label_zero, label_zero ^ delta]);
|
||||
// }
|
||||
// let prover_labels = choose(&all_binary_labels, &u8vec_to_boolvec(&plaintext));
|
||||
let verifier = verifier.setup().unwrap();
|
||||
|
||||
// let verifier = LabelsumVerifier::new(
|
||||
// all_binary_labels.clone(),
|
||||
// Box::new(verifiernode::VerifierNode {}),
|
||||
// );
|
||||
let prover = LabelsumProver::new(plaintext, prover);
|
||||
|
||||
// let verifier = verifier.setup().unwrap();
|
||||
// Perform setup
|
||||
let prover = prover.setup().unwrap();
|
||||
|
||||
// let prover = LabelsumProver::new(
|
||||
// proving_key,
|
||||
// prime,
|
||||
// plaintext,
|
||||
// poseidon,
|
||||
// Box::new(provernode::ProverNode {}),
|
||||
// );
|
||||
// Commitment to the plaintext is sent to the Notary
|
||||
let (plaintext_hash, prover) = prover.plaintext_commitment().unwrap();
|
||||
|
||||
// // Perform setup
|
||||
// let prover = prover.setup().unwrap();
|
||||
// Notary sends back encrypted arithm. labels.
|
||||
let (ciphertexts, verifier) =
|
||||
verifier.receive_plaintext_hashes(plaintext_hash).unwrap();
|
||||
|
||||
// // Commitment to the plaintext is sent to the Notary
|
||||
// let (plaintext_hash, prover) = prover.plaintext_commitment().unwrap();
|
||||
// Hash commitment to the label_sum is sent to the Notary
|
||||
let (label_sum_hashes, prover) = prover
|
||||
.labelsum_commitment(ciphertexts, &prover_labels)
|
||||
.unwrap();
|
||||
|
||||
// // Notary sends back encrypted arithm. labels.
|
||||
// let (cipheretexts, verifier) = verifier.receive_plaintext_hashes(plaintext_hash);
|
||||
// Notary sends the arithmetic label seed
|
||||
let (seed, verifier) = verifier.receive_labelsum_hashes(label_sum_hashes).unwrap();
|
||||
|
||||
// // Hash commitment to the label_sum is sent to the Notary
|
||||
// let (label_sum_hashes, prover) = prover
|
||||
// .labelsum_commitment(cipheretexts, &prover_labels)
|
||||
// .unwrap();
|
||||
// At this point the following happens in the `committed GC` protocol:
|
||||
// - the Notary reveals the GC seed
|
||||
// - the User checks that the GC was created from that seed
|
||||
// - the User checks that her active output labels correspond to the
|
||||
// output labels derived from the seed
|
||||
// - we are called with the result of the check and (if successful)
|
||||
// with all the output labels
|
||||
|
||||
// // Notary sends the arithmetic label seed
|
||||
// let (seed, verifier) = verifier.receive_labelsum_hashes(label_sum_hashes);
|
||||
let prover = prover
|
||||
.binary_labels_authenticated(true, Some(all_binary_labels))
|
||||
.unwrap();
|
||||
|
||||
// // At this point the following happens in the `committed GC` protocol:
|
||||
// // - the Notary reveals the GC seed
|
||||
// // - the User checks that the GC was created from that seed
|
||||
// // - the User checks that her active output labels correspond to the
|
||||
// // output labels derived from the seed
|
||||
// // - we are called with the result of the check and (if successful)
|
||||
// // with all the output labels
|
||||
// Prover checks the integrity of the arithmetic labels and generates zero_sums and deltas
|
||||
let prover = prover.authenticate_arithmetic_labels(seed).unwrap();
|
||||
|
||||
// let prover = prover
|
||||
// .binary_labels_authenticated(true, Some(all_binary_labels))
|
||||
// .unwrap();
|
||||
// Prover generates the proof
|
||||
let (proofs, prover) = prover.create_zk_proofs().unwrap();
|
||||
|
||||
// // Prover checks the integrity of the arithmetic labels and generates zero_sums and deltas
|
||||
// let prover = prover.authenticate_arithmetic_labels(seed).unwrap();
|
||||
|
||||
// // Prover generates the proof
|
||||
// let (proofs, prover) = prover.create_zk_proof().unwrap();
|
||||
|
||||
// // Notary verifies the proof
|
||||
// let verifier = verifier.verify_many(proofs).unwrap();
|
||||
// assert_eq!(
|
||||
// type_of(&verifier),
|
||||
// "labelsum::verifier::LabelsumVerifier<labelsum::verifier::VerificationSuccessfull>"
|
||||
// );
|
||||
// }
|
||||
// Notary verifies the proof
|
||||
let verifier = verifier.verify_many(proofs).unwrap();
|
||||
assert_eq!(
|
||||
type_of(&verifier),
|
||||
"labelsum::verifier::LabelsumVerifier<labelsum::verifier::VerificationSuccessfull>"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
1193
src/prover.rs
1193
src/prover.rs
File diff suppressed because it is too large
Load Diff
@@ -1 +1,32 @@
|
||||
pub mod onetimesetup;
|
||||
pub mod poseidon;
|
||||
pub mod provernode;
|
||||
pub mod verifiernode;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::onetimesetup::OneTimeSetup;
|
||||
use super::provernode::Prover;
|
||||
use super::verifiernode::Verifier;
|
||||
use crate::tests::fixtures::e2e_test;
|
||||
|
||||
#[test]
|
||||
fn snarkjs_e2e_test() {
|
||||
let prover_ots = OneTimeSetup::new();
|
||||
let verifier_ots = OneTimeSetup::new();
|
||||
|
||||
// The Prover should have generated the proving key (before the labelsum
|
||||
// protocol starts) like this:
|
||||
prover_ots.setup().unwrap();
|
||||
let proving_key = prover_ots.get_proving_key().unwrap();
|
||||
|
||||
// The Verifier should have generated the verifying key (before the labelsum
|
||||
// protocol starts) like this:
|
||||
verifier_ots.setup().unwrap();
|
||||
let verification_key = verifier_ots.get_verification_key().unwrap();
|
||||
|
||||
let prover = Box::new(Prover::new(proving_key));
|
||||
let verifier = Box::new(Verifier::new(verification_key));
|
||||
e2e_test(prover, verifier);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,16 +41,16 @@ impl OneTimeSetup {
|
||||
|
||||
pub fn setup(&self) -> Result<(), Error> {
|
||||
// check if files which we ship are present
|
||||
if !Path::new("powersOfTau28_hez_final_14.ptau").exists()
|
||||
|| !Path::new("circuit.r1cs").exists()
|
||||
if !Path::new("circom/powersOfTau28_hez_final_14.ptau").exists()
|
||||
|| !Path::new("circom/circuit.r1cs").exists()
|
||||
{
|
||||
return Err(Error::FileDoesNotExist);
|
||||
}
|
||||
// check if any of the files hasn't been generated. If so, regenerate
|
||||
// all files.
|
||||
if !Path::new("circuit_0000.zkey").exists()
|
||||
|| !Path::new("circuit_final.zkey.notary").exists()
|
||||
|| !Path::new("verification_key.json").exists()
|
||||
if !Path::new("circom/circuit_0000.zkey").exists()
|
||||
|| !Path::new("circom/circuit_final.zkey.notary").exists()
|
||||
|| !Path::new("circom/verification_key.json").exists()
|
||||
{
|
||||
let entropy = self.generate_entropy();
|
||||
//return self.regenerate1(entropy);
|
||||
@@ -62,12 +62,22 @@ impl OneTimeSetup {
|
||||
|
||||
// Returns the already existing proving key
|
||||
pub fn get_proving_key(&self) -> Result<Vec<u8>, Error> {
|
||||
let path = Path::new("circuit_final.zkey.notary");
|
||||
let path = Path::new("circom/circuit_final.zkey.notary");
|
||||
if !path.exists() {
|
||||
return Err(Error::FileDoesNotExist);
|
||||
}
|
||||
let proof = fs::read(path.clone()).unwrap();
|
||||
Ok(proof)
|
||||
let key = fs::read(path.clone()).unwrap();
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
// Returns the already existing verification key
|
||||
pub fn get_verification_key(&self) -> Result<Vec<u8>, Error> {
|
||||
let path = Path::new("circom/verification_key.json");
|
||||
if !path.exists() {
|
||||
return Err(Error::FileDoesNotExist);
|
||||
}
|
||||
let key = fs::read(path.clone()).unwrap();
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
// this will work only if snarkjs is in the PATH
|
||||
@@ -76,9 +86,9 @@ impl OneTimeSetup {
|
||||
.args([
|
||||
"groth16",
|
||||
"setup",
|
||||
"circuit.r1cs",
|
||||
"powersOfTau28_hez_final_14.ptau",
|
||||
"circuit_0000.zkey",
|
||||
"circom/circuit.r1cs",
|
||||
"circom/powersOfTau28_hez_final_14.ptau",
|
||||
"circom/circuit_0000.zkey",
|
||||
])
|
||||
.output();
|
||||
self.check_output(output)?;
|
||||
@@ -87,8 +97,8 @@ impl OneTimeSetup {
|
||||
.args([
|
||||
"zkey",
|
||||
"contribute",
|
||||
"circuit_0000.zkey",
|
||||
"circuit_final.zkey.notary",
|
||||
"circom/circuit_0000.zkey",
|
||||
"circom/circuit_final.zkey.notary",
|
||||
&(String::from("-e=\"") + &entropy + &String::from("\"")),
|
||||
])
|
||||
.output();
|
||||
@@ -99,8 +109,8 @@ impl OneTimeSetup {
|
||||
"zkey",
|
||||
"export",
|
||||
"verificationkey",
|
||||
"circuit_final.zkey.notary",
|
||||
"verification_key.json",
|
||||
"circom/circuit_final.zkey.notary",
|
||||
"circom/verification_key.json",
|
||||
])
|
||||
.output();
|
||||
self.check_output(output)?;
|
||||
@@ -111,19 +121,9 @@ impl OneTimeSetup {
|
||||
// call a js wrapper which does what regenerate1() above does
|
||||
fn regenerate2(&self, entropy: String) -> Result<(), Error> {
|
||||
let output = Command::new("node")
|
||||
.args(["onetimesetup.mjs", &entropy])
|
||||
.args(["circom/onetimesetup.mjs", &entropy])
|
||||
.output();
|
||||
self.check_output(output)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
#[test]
|
||||
fn test() {
|
||||
let mut ots = OneTimeSetup::new();
|
||||
ots.setup().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +1,15 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
use ark_bn254::Fr as F;
|
||||
use ark_ff::{One, PrimeField};
|
||||
use ark_sponge::poseidon::{PoseidonConfig, PoseidonSponge};
|
||||
use ark_sponge::CryptographicSponge;
|
||||
use ark_sponge::FieldBasedCryptographicSponge;
|
||||
use ark_sponge::{CryptographicSponge, DuplexSpongeMode};
|
||||
use lazy_static::lazy_static;
|
||||
use num::{BigUint, FromPrimitive, Num, ToPrimitive, Zero};
|
||||
use num::{BigUint, Num};
|
||||
use regex::Regex;
|
||||
use std::fs::File;
|
||||
use std::io::prelude::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Error {
|
||||
Error1,
|
||||
}
|
||||
/// additive round keys for a specific Poseidon rate
|
||||
/// outer vec length equals to total (partial + full) round count
|
||||
/// inner vec length equals rate + capacity (our Poseidon's capacity is fixed at 1)
|
||||
@@ -40,7 +35,7 @@ pub struct Poseidon {
|
||||
|
||||
impl Poseidon {
|
||||
pub fn new() -> Poseidon {
|
||||
let (arks, mdss) = setup().unwrap();
|
||||
let (arks, mdss) = setup();
|
||||
Poseidon { arks, mdss }
|
||||
}
|
||||
|
||||
@@ -75,8 +70,8 @@ impl Poseidon {
|
||||
}
|
||||
}
|
||||
|
||||
fn setup() -> Result<(Vec<Ark>, Vec<Mds>), Error> {
|
||||
let mut file = File::open("poseidon_constants_old.circom").unwrap();
|
||||
fn setup() -> (Vec<Ark>, Vec<Mds>) {
|
||||
let mut file = File::open("circom/poseidon_constants_old.circom").unwrap();
|
||||
let mut contents = String::new();
|
||||
file.read_to_string(&mut contents).unwrap();
|
||||
|
||||
@@ -124,13 +119,5 @@ fn setup() -> Result<(Vec<Ark>, Vec<Mds>), Error> {
|
||||
}
|
||||
// we should have consumed all elements
|
||||
assert!(v.len() == offset);
|
||||
Ok((arks, mdss))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stuff() {
|
||||
let p = Poseidon::new();
|
||||
let input = vec![BigUint::from_u8(0).unwrap(); 16];
|
||||
let out = p.hash(&input);
|
||||
print!("{:?} out:", out);
|
||||
(arks, mdss)
|
||||
}
|
||||
|
||||
@@ -1,16 +1,85 @@
|
||||
use super::poseidon::Poseidon;
|
||||
use crate::prover::ProofInput;
|
||||
use crate::prover::{Prove, ProverError};
|
||||
use json::{object, stringify_pretty};
|
||||
use num::{BigUint, FromPrimitive, ToPrimitive, Zero};
|
||||
use std::env::temp_dir;
|
||||
use std::fs;
|
||||
use std::process::{Command, Output};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub struct ProverNode {}
|
||||
pub struct Prover {
|
||||
proving_key: Vec<u8>,
|
||||
poseidon: Poseidon,
|
||||
}
|
||||
|
||||
impl Prover {
|
||||
pub fn new(proving_key: Vec<u8>) -> Self {
|
||||
Self {
|
||||
proving_key,
|
||||
poseidon: Poseidon::new(),
|
||||
}
|
||||
}
|
||||
|
||||
// Creates inputs in the "input.json" format
|
||||
fn create_proof_inputs(&self, input: ProofInput) -> String {
|
||||
// convert each field element of plaintext into a string
|
||||
let plaintext: Vec<String> = input
|
||||
.plaintext
|
||||
.iter()
|
||||
.map(|bigint| bigint.to_string())
|
||||
.collect();
|
||||
|
||||
// convert all deltas to strings
|
||||
let deltas_str: Vec<String> = input.deltas.iter().map(|v| v.to_string()).collect();
|
||||
|
||||
// split deltas into groups corresponding to the field elements
|
||||
// of our Poseidon circuit
|
||||
let deltas_fes: Vec<&[String]> = deltas_str.chunks(self.useful_bits()).collect();
|
||||
|
||||
// prepare input.json
|
||||
let input = object! {
|
||||
plaintext_hash: input.plaintext_hash.to_string(),
|
||||
label_sum_hash: input.label_sum_hash.to_string(),
|
||||
sum_of_zero_labels: input.sum_of_zero_labels.to_string(),
|
||||
plaintext: plaintext,
|
||||
salt: input.salt.to_string(),
|
||||
delta: deltas_fes[0..deltas_fes.len()-1],
|
||||
// last field element's deltas are a separate input
|
||||
delta_last: deltas_fes[deltas_fes.len()-1]
|
||||
};
|
||||
stringify_pretty(input, 4)
|
||||
}
|
||||
}
|
||||
|
||||
impl Prove for Prover {
|
||||
fn useful_bits(&self) -> usize {
|
||||
253
|
||||
}
|
||||
|
||||
fn poseidon_rate(&self) -> usize {
|
||||
16
|
||||
}
|
||||
|
||||
fn permutation_count(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn salt_size(&self) -> usize {
|
||||
128
|
||||
}
|
||||
|
||||
fn chunk_size(&self) -> usize {
|
||||
3920 //253*15+125
|
||||
}
|
||||
|
||||
fn hash(&self, inputs: &Vec<BigUint>) -> Result<BigUint, ProverError> {
|
||||
Ok(self.poseidon.hash(inputs))
|
||||
}
|
||||
|
||||
impl Prove for ProverNode {
|
||||
/// Produces a groth16 proof with snarkjs. Input must be a JSON string in the
|
||||
/// "input.json" format which snarkjs expects.
|
||||
fn prove(&self, input: String, proving_key: &Vec<u8>) -> Result<Vec<u8>, ProverError> {
|
||||
fn prove(&self, input: ProofInput) -> Result<Vec<u8>, ProverError> {
|
||||
let mut path1 = temp_dir();
|
||||
let mut path2 = temp_dir();
|
||||
let mut path3 = temp_dir();
|
||||
@@ -18,11 +87,13 @@ impl Prove for ProverNode {
|
||||
path2.push(format!("proving_key.zkey.{}", Uuid::new_v4()));
|
||||
path3.push(format!("proof.json.{}", Uuid::new_v4()));
|
||||
|
||||
let input = self.create_proof_inputs(input);
|
||||
|
||||
fs::write(path1.clone(), input).expect("Unable to write file");
|
||||
fs::write(path2.clone(), proving_key).expect("Unable to write file");
|
||||
fs::write(path2.clone(), self.proving_key.clone()).expect("Unable to write file");
|
||||
let output = Command::new("node")
|
||||
.args([
|
||||
"prove.mjs",
|
||||
"circom/prove.mjs",
|
||||
path1.to_str().unwrap(),
|
||||
path2.to_str().unwrap(),
|
||||
path3.to_str().unwrap(),
|
||||
@@ -51,10 +122,10 @@ impl Prove for ProverNode {
|
||||
|
||||
fn check_output(output: Result<Output, std::io::Error>) -> Result<(), ProverError> {
|
||||
if output.is_err() {
|
||||
return Err(ProverError::SnarkjsError);
|
||||
return Err(ProverError::ProvingBackendError);
|
||||
}
|
||||
if !output.unwrap().status.success() {
|
||||
return Err(ProverError::SnarkjsError);
|
||||
return Err(ProverError::ProvingBackendError);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::verifier::VerificationInput;
|
||||
use crate::verifier::{VerifierError, Verify};
|
||||
use json::{array, object, stringify, stringify_pretty, JsonValue};
|
||||
use num::{BigUint, FromPrimitive, ToPrimitive, Zero};
|
||||
@@ -7,23 +8,36 @@ use std::path::Path;
|
||||
use std::process::{Command, Output};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub struct VerifierNode {}
|
||||
pub struct Verifier {
|
||||
verification_key: Vec<u8>,
|
||||
}
|
||||
impl Verifier {
|
||||
pub fn new(verification_key: Vec<u8>) -> Self {
|
||||
Self { verification_key }
|
||||
}
|
||||
}
|
||||
|
||||
impl Verify for VerifierNode {
|
||||
fn verify(
|
||||
&self,
|
||||
proof: Vec<u8>,
|
||||
deltas: Vec<String>,
|
||||
plaintext_hash: BigUint,
|
||||
labelsum_hash: BigUint,
|
||||
zero_sum: BigUint,
|
||||
) -> Result<bool, VerifierError> {
|
||||
impl Verify for Verifier {
|
||||
fn field_size(&self) -> usize {
|
||||
254
|
||||
}
|
||||
|
||||
fn useful_bits(&self) -> usize {
|
||||
253
|
||||
}
|
||||
|
||||
fn chunk_size(&self) -> usize {
|
||||
3920 //253*15+125
|
||||
}
|
||||
|
||||
fn verify(&self, input: VerificationInput) -> Result<bool, VerifierError> {
|
||||
// public.json is a flat array
|
||||
let mut public_json: Vec<String> = Vec::new();
|
||||
public_json.push(plaintext_hash.to_string());
|
||||
public_json.push(labelsum_hash.to_string());
|
||||
public_json.extend::<Vec<String>>(deltas);
|
||||
public_json.push(zero_sum.to_string());
|
||||
public_json.push(input.plaintext_hash.to_string());
|
||||
public_json.push(input.label_sum_hash.to_string());
|
||||
let delta_str: Vec<String> = input.deltas.iter().map(|v| v.to_string()).collect();
|
||||
public_json.extend::<Vec<String>>(delta_str);
|
||||
public_json.push(input.sum_of_zero_labels.to_string());
|
||||
let s = stringify(JsonValue::from(public_json.clone()));
|
||||
|
||||
// write into temp files and delete the files after verification
|
||||
@@ -32,11 +46,11 @@ impl Verify for VerifierNode {
|
||||
path1.push(format!("public.json.{}", Uuid::new_v4()));
|
||||
path2.push(format!("proof.json.{}", Uuid::new_v4()));
|
||||
fs::write(path1.clone(), s).expect("Unable to write file");
|
||||
fs::write(path2.clone(), proof).expect("Unable to write file");
|
||||
fs::write(path2.clone(), input.proof).expect("Unable to write file");
|
||||
|
||||
let output = Command::new("node")
|
||||
.args([
|
||||
"verify.mjs",
|
||||
"circom/verify.mjs",
|
||||
path1.to_str().unwrap(),
|
||||
path2.to_str().unwrap(),
|
||||
])
|
||||
@@ -54,10 +68,10 @@ impl Verify for VerifierNode {
|
||||
|
||||
fn check_output(output: &Result<Output, std::io::Error>) -> Result<(), VerifierError> {
|
||||
if output.is_err() {
|
||||
return Err(VerifierError::SnarkjsError);
|
||||
return Err(VerifierError::VerifyingBackendError);
|
||||
}
|
||||
if !output.as_ref().unwrap().status.success() {
|
||||
return Err(VerifierError::SnarkjsError);
|
||||
return Err(VerifierError::VerifyingBackendError);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
190
src/utils.rs
Normal file
190
src/utils.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
use crate::{Delta, ZeroSum};
|
||||
use aes::{Aes128, NewBlockCipher};
|
||||
use ark_ff::BigInt;
|
||||
use cipher::{consts::U16, generic_array::GenericArray, BlockEncrypt};
|
||||
use num::BigUint;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
/// Converts bits in MSB-first order into a `BigUint`
|
||||
pub fn bits_to_bigint(bits: &[bool]) -> BigUint {
|
||||
BigUint::from_bytes_be(&boolvec_to_u8vec(&bits))
|
||||
}
|
||||
#[test]
|
||||
fn test_bits_to_bigint() {
|
||||
let bits = [true, false];
|
||||
assert_eq!(bits_to_bigint(&bits), 2u8.into());
|
||||
}
|
||||
|
||||
/// Converts bits in MSB-first order into BE bytes. The bits will be left-padded
|
||||
/// with zeroes to the nearest multiple of 8.
|
||||
pub fn boolvec_to_u8vec(bv: &[bool]) -> Vec<u8> {
|
||||
let rem = bv.len() % 8;
|
||||
let first_byte_bitsize = if rem == 0 { 8 } else { rem };
|
||||
let offset = if rem == 0 { 0 } else { 1 };
|
||||
let mut v = vec![0u8; bv.len() / 8 + offset];
|
||||
// implicitely left-pad the first byte with zeroes
|
||||
for (i, b) in bv[0..first_byte_bitsize].iter().enumerate() {
|
||||
v[i / 8] |= (*b as u8) << (first_byte_bitsize - 1 - i);
|
||||
}
|
||||
for (i, b) in bv[first_byte_bitsize..].iter().enumerate() {
|
||||
v[1 + i / 8] |= (*b as u8) << (7 - (i % 8));
|
||||
}
|
||||
v
|
||||
}
|
||||
#[test]
|
||||
fn test_boolvec_to_u8vec() {
|
||||
let bits = [true, false];
|
||||
assert_eq!(boolvec_to_u8vec(&bits), [2]);
|
||||
|
||||
let bits = [true, false, false, false, false, false, false, true, true];
|
||||
assert_eq!(boolvec_to_u8vec(&bits), [1, 3]);
|
||||
}
|
||||
|
||||
/// Converts BE bytes into bits in MSB-first order, left-padding with zeroes
|
||||
/// to the nearest multiple of 8.
|
||||
pub fn u8vec_to_boolvec(v: &[u8]) -> Vec<bool> {
|
||||
let mut bv = Vec::with_capacity(v.len() * 8);
|
||||
for byte in v.iter() {
|
||||
for i in 0..8 {
|
||||
bv.push(((byte >> (7 - i)) & 1) != 0);
|
||||
}
|
||||
}
|
||||
bv
|
||||
}
|
||||
#[test]
|
||||
fn test_u8vec_to_boolvec() {
|
||||
let bytes = [1];
|
||||
assert_eq!(
|
||||
u8vec_to_boolvec(&bytes),
|
||||
[false, false, false, false, false, false, false, true]
|
||||
);
|
||||
|
||||
let bytes = [255, 2];
|
||||
assert_eq!(
|
||||
u8vec_to_boolvec(&bytes),
|
||||
[
|
||||
true, true, true, true, true, true, true, true, false, false, false, false, false,
|
||||
false, true, false
|
||||
]
|
||||
);
|
||||
|
||||
// convert to bits and back to bytes
|
||||
let bignum: BigUint = 3898219876643u128.into();
|
||||
let bits = u8vec_to_boolvec(&bignum.to_bytes_be());
|
||||
let bytes = boolvec_to_u8vec(&bits);
|
||||
assert_eq!(bignum, BigUint::from_bytes_be(&bytes));
|
||||
}
|
||||
|
||||
/// Returns sha256 hash digest
|
||||
pub fn sha256(data: &[u8]) -> [u8; 32] {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(data);
|
||||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
/// Encrypts each arithmetic label using a corresponding binary label as a key
|
||||
/// and returns ciphertexts in an order based on binary label's pointer bit (LSB).
|
||||
pub fn encrypt_arithmetic_labels(
|
||||
alabels: &Vec<[BigUint; 2]>,
|
||||
blabels: &Vec<[u128; 2]>,
|
||||
) -> Result<Vec<[[u8; 16]; 2]>, String> {
|
||||
if alabels.len() > blabels.len() {
|
||||
return Err("error".to_string());
|
||||
}
|
||||
|
||||
Ok(blabels
|
||||
.iter()
|
||||
.zip(alabels)
|
||||
.map(|(bin_pair, arithm_pair)| {
|
||||
// safe to unwrap() since to_be_bytes() always returns exactly 16
|
||||
// bytes for u128
|
||||
let zero_key = Aes128::new_from_slice(&bin_pair[0].to_be_bytes()).unwrap();
|
||||
let one_key = Aes128::new_from_slice(&bin_pair[1].to_be_bytes()).unwrap();
|
||||
|
||||
let mut label0 = [0u8; 16];
|
||||
let mut label1 = [0u8; 16];
|
||||
let ap0 = arithm_pair[0].to_bytes_be();
|
||||
let ap1 = arithm_pair[1].to_bytes_be();
|
||||
// pad with zeroes on the left
|
||||
label0[16 - ap0.len()..].copy_from_slice(&ap0);
|
||||
label1[16 - ap1.len()..].copy_from_slice(&ap1);
|
||||
|
||||
let mut label0: GenericArray<u8, U16> = GenericArray::from(label0);
|
||||
let mut label1: GenericArray<u8, U16> = GenericArray::from(label1);
|
||||
zero_key.encrypt_block(&mut label0);
|
||||
one_key.encrypt_block(&mut label1);
|
||||
// place encrypted arithmetic labels based on the pointer bit of
|
||||
// binary label 0
|
||||
if (bin_pair[0] & 1) == 0 {
|
||||
[label0.into(), label1.into()]
|
||||
} else {
|
||||
[label1.into(), label0.into()]
|
||||
}
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
#[test]
|
||||
fn test_encrypt_arithmetic_labels() {
|
||||
let alabels: [BigUint; 2] = [3u8.into(), 4u8.into()];
|
||||
let blabels = [0u128, 1u128];
|
||||
let res = encrypt_arithmetic_labels(&vec![alabels], &vec![blabels]).unwrap();
|
||||
let flat = res[0]
|
||||
.into_iter()
|
||||
.map(|ct| ct)
|
||||
.flatten()
|
||||
.collect::<Vec<_>>();
|
||||
// expected value generated with python3:
|
||||
// from Crypto.Cipher import AES
|
||||
// k0 = AES.new((0).to_bytes(16, 'big'), AES.MODE_ECB)
|
||||
// ct0 = k0.encrypt((3).to_bytes(16, 'big')).hex()
|
||||
// k1 = AES.new((1).to_bytes(16, 'big'), AES.MODE_ECB)
|
||||
// ct1 = k1.encrypt((4).to_bytes(16, 'big')).hex()
|
||||
// print(ct0+ct1)
|
||||
let expected = "f795aaab494b5923f7fd89ff948bc1e0382fa171550467b34c54c58b9d3cfd24";
|
||||
assert_eq!(hex::encode(&flat), expected);
|
||||
}
|
||||
|
||||
/// Returns the sum of all zero labels and deltas for each label pair.
|
||||
pub fn compute_zero_sum_and_deltas(
|
||||
arithmetic_label_pairs: &[[BigUint; 2]],
|
||||
) -> (ZeroSum, Vec<Delta>) {
|
||||
let mut deltas: Vec<Delta> = Vec::with_capacity(arithmetic_label_pairs.len());
|
||||
let mut zero_sum: ZeroSum = 0u8.into();
|
||||
for label_pair in arithmetic_label_pairs {
|
||||
// calculate the sum of all zero labels
|
||||
zero_sum += label_pair[0].clone();
|
||||
// put deltas from into one vec
|
||||
deltas.push(label_pair[1].clone() - label_pair[0].clone());
|
||||
}
|
||||
(zero_sum, deltas)
|
||||
}
|
||||
|
||||
#[test]
|
||||
/// Tests compute_zero_sum_and_deltas()
|
||||
fn test_compute_zero_sum_and_deltas() {
|
||||
let labels: [[BigUint; 2]; 2] = [[1u8.into(), 2u8.into()], [3u8.into(), 4u8.into()]];
|
||||
let (z, d) = compute_zero_sum_and_deltas(&labels);
|
||||
|
||||
assert_eq!(z, 4u8.into());
|
||||
assert_eq!(d, [1u8.into(), 1u8.into()]);
|
||||
}
|
||||
|
||||
/// Make sure that the `BigUint`s bitsize is not larger than `bitsize`
|
||||
pub fn sanitize_biguint(input: &BigUint, bitsize: usize) -> Result<(), String> {
|
||||
if (input.bits() as usize) > bitsize {
|
||||
return Err("error".to_string());
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
#[test]
|
||||
/// Tests sanitize_biguint()
|
||||
fn test_sanitize_biguint() {
|
||||
let good = BigUint::from(2u8).pow(253) - BigUint::from(1u8);
|
||||
let res = sanitize_biguint(&good, 253);
|
||||
assert!(!res.is_err());
|
||||
|
||||
let bad = BigUint::from(2u8).pow(253);
|
||||
let res = sanitize_biguint(&bad, 253);
|
||||
assert!(res.is_err());
|
||||
}
|
||||
449
src/verifier.rs
449
src/verifier.rs
@@ -1,23 +1,26 @@
|
||||
use crate::boolvec_to_u8vec;
|
||||
use super::ARITHMETIC_LABEL_SIZE;
|
||||
use crate::label::{LabelGenerator, Seed};
|
||||
use crate::utils::{compute_zero_sum_and_deltas, encrypt_arithmetic_labels, sanitize_biguint};
|
||||
use crate::{Delta, LabelsumHash, PlaintextHash, Proof, ZeroSum};
|
||||
use num::BigUint;
|
||||
|
||||
use super::{encrypt_arithmetic_labels, random_bigint, ARITHMETIC_LABEL_SIZE, POSEIDON_RATE};
|
||||
use crate::label::{LabelGenerator, LabelPair, Seed};
|
||||
use aes::{Aes128, NewBlockCipher};
|
||||
use cipher::{consts::U16, generic_array::GenericArray, BlockCipher, BlockEncrypt};
|
||||
use num::{BigUint, FromPrimitive, ToPrimitive, Zero};
|
||||
use rand::SeedableRng;
|
||||
use rand::{thread_rng, Rng};
|
||||
use rand_chacha::ChaCha20Rng;
|
||||
// The PRG we use to generate arithmetic labels
|
||||
type Prg = ChaCha20Rng;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum VerifierError {
|
||||
ProvingKeyNotFound,
|
||||
FileSystemError,
|
||||
FileDoesNotExist,
|
||||
SnarkjsError,
|
||||
WrongProofCount,
|
||||
BigUintTooLarge,
|
||||
VerifyingBackendError,
|
||||
VerificationFailed,
|
||||
InternalError,
|
||||
}
|
||||
|
||||
/// Public inputs and a zk proof that needs to be verified.
|
||||
#[derive(Default)]
|
||||
pub struct VerificationInput {
|
||||
pub plaintext_hash: PlaintextHash,
|
||||
pub label_sum_hash: LabelsumHash,
|
||||
pub sum_of_zero_labels: ZeroSum,
|
||||
pub deltas: Vec<Delta>,
|
||||
pub proof: Proof,
|
||||
}
|
||||
|
||||
pub trait State {}
|
||||
@@ -25,51 +28,57 @@ pub trait State {}
|
||||
pub struct Setup {
|
||||
binary_labels: Vec<[u128; 2]>,
|
||||
}
|
||||
impl State for Setup {}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ReceivePlaintextHashes {
|
||||
deltas: Vec<BigUint>,
|
||||
/// The sum of all arithmetic labels with the semantic value 0. One sum
|
||||
/// for each chunk of the plaintext.
|
||||
zero_sums: Vec<BigUint>,
|
||||
ciphertexts: Vec<[Vec<u8>; 2]>,
|
||||
/// The PRG seed from which all arithmetic labels were generated
|
||||
deltas: Vec<Delta>,
|
||||
zero_sums: Vec<ZeroSum>,
|
||||
ciphertexts: Vec<[[u8; 16]; 2]>,
|
||||
arith_label_seed: Seed,
|
||||
}
|
||||
impl State for ReceivePlaintextHashes {}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ReceiveLabelsumHashes {
|
||||
deltas: Vec<BigUint>,
|
||||
zero_sums: Vec<BigUint>,
|
||||
// hashes for each chunk of Prover's plaintext
|
||||
plaintext_hashes: Vec<BigUint>,
|
||||
arith_label_seed: [u8; 32],
|
||||
deltas: Vec<Delta>,
|
||||
zero_sums: Vec<ZeroSum>,
|
||||
plaintext_hashes: Vec<PlaintextHash>,
|
||||
arith_label_seed: Seed,
|
||||
}
|
||||
impl State for ReceiveLabelsumHashes {}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct VerifyMany {
|
||||
deltas: Vec<BigUint>,
|
||||
zero_sums: Vec<BigUint>,
|
||||
plaintext_hashes: Vec<BigUint>,
|
||||
labelsum_hashes: Vec<BigUint>,
|
||||
deltas: Vec<Delta>,
|
||||
zero_sums: Vec<ZeroSum>,
|
||||
plaintext_hashes: Vec<PlaintextHash>,
|
||||
labelsum_hashes: Vec<LabelsumHash>,
|
||||
}
|
||||
impl State for VerifyMany {}
|
||||
|
||||
pub struct VerificationSuccessfull {
|
||||
plaintext_hashes: Vec<BigUint>,
|
||||
plaintext_hashes: Vec<PlaintextHash>,
|
||||
}
|
||||
|
||||
impl State for Setup {}
|
||||
impl State for ReceivePlaintextHashes {}
|
||||
impl State for ReceiveLabelsumHashes {}
|
||||
impl State for VerifyMany {}
|
||||
impl State for VerificationSuccessfull {}
|
||||
|
||||
pub trait Verify {
|
||||
fn verify(
|
||||
&self,
|
||||
proof: Vec<u8>,
|
||||
deltas: Vec<String>,
|
||||
plaintext_hash: BigUint,
|
||||
labelsum_hash: BigUint,
|
||||
zero_sum: BigUint,
|
||||
) -> Result<bool, VerifierError>;
|
||||
/// Verifies the zk proof against public `input`s. Returns `true` on success,
|
||||
/// `false` otherwise.
|
||||
fn verify(&self, input: VerificationInput) -> Result<bool, VerifierError>;
|
||||
|
||||
/// The EC field size in bits. Verifier uses this to sanitize the `BigUint`s
|
||||
/// received from Prover.
|
||||
fn field_size(&self) -> usize;
|
||||
|
||||
/// Returns how many bits of plaintext we will pack into one field element.
|
||||
/// Normally, this should be [Verify::field_size] minus 1.
|
||||
fn useful_bits(&self) -> usize;
|
||||
|
||||
/// How many bits of [Plaintext] can fit into one [Chunk]. This does not
|
||||
/// include the [Salt] of the hash - which takes up the remaining least bits
|
||||
/// of the last field element of each chunk.
|
||||
fn chunk_size(&self) -> usize;
|
||||
}
|
||||
pub struct LabelsumVerifier<S = Setup>
|
||||
where
|
||||
@@ -80,6 +89,7 @@ where
|
||||
}
|
||||
|
||||
impl LabelsumVerifier {
|
||||
/// Returns the next expected state.
|
||||
pub fn new(
|
||||
binary_labels: Vec<[u128; 2]>,
|
||||
verifier: Box<dyn Verify>,
|
||||
@@ -92,40 +102,33 @@ impl LabelsumVerifier {
|
||||
}
|
||||
|
||||
impl LabelsumVerifier<Setup> {
|
||||
/// Generates arith. labels from a seed and encrypts them using binary labels
|
||||
/// as encryption keys.
|
||||
/// Generates arithmetic labels from a seed, computes the deltas, computes
|
||||
/// the sum of zero labels, encrypts arithmetic labels using binary
|
||||
/// labels as encryption keys.
|
||||
///
|
||||
/// Returns the next expected state.
|
||||
pub fn setup(self) -> Result<LabelsumVerifier<ReceivePlaintextHashes>, VerifierError> {
|
||||
let plaintext_bitsize = self.state.binary_labels.len();
|
||||
// Compute useful bits from the field prime
|
||||
let chunk_size = 253 * POSEIDON_RATE - 128;
|
||||
// count of chunks rounded up
|
||||
let chunk_count = (plaintext_bitsize + (chunk_size - 1)) / chunk_size;
|
||||
// There will be as many deltas as there are garbled circuit output
|
||||
// labels.
|
||||
let mut deltas: Vec<BigUint> = Vec::with_capacity(self.state.binary_labels.len());
|
||||
|
||||
// There will be as many zero_sums as there are chunks
|
||||
let mut zero_sums: Vec<BigUint> = Vec::with_capacity(chunk_count);
|
||||
let (label_pairs, seed) =
|
||||
LabelGenerator::generate(self.state.binary_labels.len(), ARITHMETIC_LABEL_SIZE);
|
||||
|
||||
// There will be as many deltas as there are plaintext bits.
|
||||
let mut deltas: Vec<BigUint> = Vec::with_capacity(plaintext_bitsize);
|
||||
let zero_sums: Vec<ZeroSum> = label_pairs
|
||||
.chunks(self.verifier.chunk_size())
|
||||
.map(|chunk_of_alabel_pairs| {
|
||||
let (zero_sum, deltas_in_chunk) =
|
||||
compute_zero_sum_and_deltas(chunk_of_alabel_pairs);
|
||||
deltas.extend(deltas_in_chunk);
|
||||
zero_sum
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Generate arithmetic label pairs and split them into chunks
|
||||
let generator = LabelGenerator::new();
|
||||
let (label_pairs, seed, _generator) =
|
||||
generator.generate(plaintext_bitsize, ARITHMETIC_LABEL_SIZE);
|
||||
let label_pair_chunks = label_pairs.chunks(chunk_size);
|
||||
|
||||
// Calculate deltas for all chunks and zero_sums for each chunk
|
||||
for chunk in label_pair_chunks {
|
||||
let mut zero_sum = BigUint::from_u8(0).unwrap();
|
||||
for label_pair in chunk {
|
||||
zero_sum += label_pair[0].clone();
|
||||
deltas.push(label_pair[1].clone() - label_pair[0].clone());
|
||||
}
|
||||
zero_sums.push(zero_sum);
|
||||
}
|
||||
|
||||
// encrypt each arithmetic label using a corresponding binary label as a key
|
||||
// place ciphertexts in an order based on binary label's p&p bit
|
||||
let ciphertexts = encrypt_arithmetic_labels(&label_pairs, &self.state.binary_labels);
|
||||
let ciphertexts = match encrypt_arithmetic_labels(&label_pairs, &self.state.binary_labels) {
|
||||
Ok(ct) => ct,
|
||||
Err(_) => return Err(VerifierError::InternalError),
|
||||
};
|
||||
|
||||
Ok(LabelsumVerifier {
|
||||
state: ReceivePlaintextHashes {
|
||||
@@ -140,12 +143,19 @@ impl LabelsumVerifier<Setup> {
|
||||
}
|
||||
|
||||
impl LabelsumVerifier<ReceivePlaintextHashes> {
|
||||
// receive hashes of plaintext and reveal the encrypted arithmetic labels
|
||||
/// Receives hashes of plaintext and returns the encrypted
|
||||
/// arithmetic labels and the next expected state.
|
||||
pub fn receive_plaintext_hashes(
|
||||
self,
|
||||
plaintext_hashes: Vec<BigUint>,
|
||||
) -> (Vec<[Vec<u8>; 2]>, LabelsumVerifier<ReceiveLabelsumHashes>) {
|
||||
(
|
||||
plaintext_hashes: Vec<PlaintextHash>,
|
||||
) -> Result<(Vec<[[u8; 16]; 2]>, LabelsumVerifier<ReceiveLabelsumHashes>), VerifierError> {
|
||||
for h in &plaintext_hashes {
|
||||
if sanitize_biguint(h, self.verifier.field_size()).is_err() {
|
||||
return Err(VerifierError::BigUintTooLarge);
|
||||
}
|
||||
}
|
||||
|
||||
Ok((
|
||||
self.state.ciphertexts,
|
||||
LabelsumVerifier {
|
||||
state: ReceiveLabelsumHashes {
|
||||
@@ -156,18 +166,24 @@ impl LabelsumVerifier<ReceivePlaintextHashes> {
|
||||
},
|
||||
verifier: self.verifier,
|
||||
},
|
||||
)
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelsumVerifier<ReceiveLabelsumHashes> {
|
||||
// receive the hash commitment to the Prover's sum of labels and reveal
|
||||
// the arithmetic label seed
|
||||
/// Receives hashes of sums of labels and returns the arithmetic label [Seed]
|
||||
/// and the next expected state.
|
||||
pub fn receive_labelsum_hashes(
|
||||
self,
|
||||
labelsum_hashes: Vec<BigUint>,
|
||||
) -> (Seed, LabelsumVerifier<VerifyMany>) {
|
||||
(
|
||||
labelsum_hashes: Vec<LabelsumHash>,
|
||||
) -> Result<(Seed, LabelsumVerifier<VerifyMany>), VerifierError> {
|
||||
for h in &labelsum_hashes {
|
||||
if sanitize_biguint(h, self.verifier.field_size()).is_err() {
|
||||
return Err(VerifierError::BigUintTooLarge);
|
||||
}
|
||||
}
|
||||
|
||||
Ok((
|
||||
self.state.arith_label_seed,
|
||||
LabelsumVerifier {
|
||||
state: VerifyMany {
|
||||
@@ -178,60 +194,27 @@ impl LabelsumVerifier<ReceiveLabelsumHashes> {
|
||||
},
|
||||
verifier: self.verifier,
|
||||
},
|
||||
)
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelsumVerifier<VerifyMany> {
|
||||
/// Verifies as many proofs as there are [Chunk]s of the plaintext. Returns
|
||||
/// the next expected state.
|
||||
pub fn verify_many(
|
||||
self,
|
||||
proofs: Vec<Vec<u8>>,
|
||||
mut self,
|
||||
proofs: Vec<Proof>,
|
||||
) -> Result<LabelsumVerifier<VerificationSuccessfull>, VerifierError> {
|
||||
// // Write public.json. The elements must be written in the exact order
|
||||
// // as below, that's the order snarkjs expects them to be in.
|
||||
|
||||
// the last chunk will be padded with zero plaintext. We also should pad
|
||||
// the deltas of the last chunk
|
||||
// TODO remove this hard-coding
|
||||
let useful_bits = 253;
|
||||
// the size of a chunk of plaintext not counting the salt
|
||||
let chunk_size = useful_bits * 16 - 128;
|
||||
let chunk_count = (self.state.deltas.len() + (chunk_size - 1)) / chunk_size;
|
||||
assert!(proofs.len() == chunk_count);
|
||||
|
||||
// pad deltas with 0 values to make their count a multiple of a chunk size
|
||||
let delta_pad_count = chunk_size * chunk_count - self.state.deltas.len();
|
||||
let mut deltas = self.state.deltas.clone();
|
||||
deltas.extend(vec![BigUint::from_u8(0).unwrap(); delta_pad_count]);
|
||||
let deltas_chunks: Vec<&[BigUint]> = deltas.chunks(chunk_size).collect();
|
||||
|
||||
for count in 0..chunk_count {
|
||||
// There are as many deltas as there are bits in the chunk of the
|
||||
// plaintext (not counting the salt)
|
||||
let delta_str: Vec<String> =
|
||||
deltas_chunks[count].iter().map(|v| v.to_string()).collect();
|
||||
|
||||
let plaintext_hash = self.state.plaintext_hashes[count].clone();
|
||||
let labelsum_hash = self.state.labelsum_hashes[count].clone();
|
||||
let zero_sum = self.state.zero_sums[count].clone();
|
||||
|
||||
let res = self.verifier.verify(
|
||||
proofs[count].clone(),
|
||||
delta_str,
|
||||
plaintext_hash,
|
||||
labelsum_hash,
|
||||
zero_sum,
|
||||
);
|
||||
// checking both for good measure
|
||||
if res.is_err() {
|
||||
return Err(VerifierError::VerificationFailed);
|
||||
}
|
||||
// shouldn't get here if there was an error, but will check anyway
|
||||
if res.unwrap() != true {
|
||||
let inputs = self.create_verification_inputs(proofs)?;
|
||||
for input in inputs {
|
||||
let res = self.verifier.verify(input)?;
|
||||
if res != true {
|
||||
// we will never get here since "?" takes care of the
|
||||
// verification error. Still, it is good to have this check
|
||||
// just in case.
|
||||
return Err(VerifierError::VerificationFailed);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(LabelsumVerifier {
|
||||
state: VerificationSuccessfull {
|
||||
plaintext_hashes: self.state.plaintext_hashes,
|
||||
@@ -239,4 +222,208 @@ impl LabelsumVerifier<VerifyMany> {
|
||||
verifier: self.verifier,
|
||||
})
|
||||
}
|
||||
|
||||
/// Construct public inputs for the zk circuit for each [Chunk].
|
||||
fn create_verification_inputs(
|
||||
&mut self,
|
||||
proofs: Vec<Proof>,
|
||||
) -> Result<Vec<VerificationInput>, VerifierError> {
|
||||
// How many chunks of plaintext are there? ( == how many zk proofs to expect)
|
||||
// The amount of deltas corresponds to the amount of bits in the plaintext.
|
||||
// Round up the chunk count.
|
||||
let chunk_count = (self.state.deltas.len() + (self.verifier.chunk_size() - 1))
|
||||
/ self.verifier.chunk_size();
|
||||
|
||||
if proofs.len() != chunk_count {
|
||||
return Err(VerifierError::WrongProofCount);
|
||||
}
|
||||
|
||||
// Since the last chunk of plaintext is padded with zero bits, we also zero-pad
|
||||
// the corresponding deltas of the last chunk to the size of a chunk.
|
||||
let delta_pad_count = self.verifier.chunk_size() * chunk_count - self.state.deltas.len();
|
||||
let mut deltas = self.state.deltas.clone();
|
||||
deltas.extend(vec![0u8.into(); delta_pad_count]);
|
||||
|
||||
let chunks_of_deltas = deltas
|
||||
.chunks(self.verifier.chunk_size())
|
||||
.map(|i| i.to_vec())
|
||||
.collect::<Vec<Vec<_>>>();
|
||||
|
||||
Ok((0..chunk_count)
|
||||
.map(|i| VerificationInput {
|
||||
plaintext_hash: self.state.plaintext_hashes[i].clone(),
|
||||
label_sum_hash: self.state.labelsum_hashes[i].clone(),
|
||||
sum_of_zero_labels: self.state.zero_sums[i].clone(),
|
||||
deltas: chunks_of_deltas[i].clone(),
|
||||
proof: proofs[i].clone(),
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::verifier::LabelsumVerifier;
|
||||
use crate::verifier::ReceiveLabelsumHashes;
|
||||
use crate::verifier::ReceivePlaintextHashes;
|
||||
use crate::verifier::VerificationInput;
|
||||
use crate::verifier::VerifierError;
|
||||
use crate::verifier::Verify;
|
||||
use crate::verifier::VerifyMany;
|
||||
use crate::Proof;
|
||||
use num::BigUint;
|
||||
|
||||
/// The verifier who implements `Verify` with the correct values
|
||||
struct CorrectTestVerifier {}
|
||||
impl Verify for CorrectTestVerifier {
|
||||
fn verify(&self, _input: VerificationInput) -> Result<bool, VerifierError> {
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
fn field_size(&self) -> usize {
|
||||
254
|
||||
}
|
||||
|
||||
fn useful_bits(&self) -> usize {
|
||||
253
|
||||
}
|
||||
|
||||
fn chunk_size(&self) -> usize {
|
||||
3670
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
/// Provide `BigUint` larger than useful_bits() and trigger
|
||||
/// [VerifierError::BigUintTooLarge]
|
||||
fn test_error_biguint_too_large() {
|
||||
// test receive_plaintext_hashes()
|
||||
let lsv = LabelsumVerifier {
|
||||
state: ReceivePlaintextHashes::default(),
|
||||
verifier: Box::new(CorrectTestVerifier {}),
|
||||
};
|
||||
|
||||
let mut hashes: Vec<BigUint> = (0..100).map(|i| BigUint::from(i as u64)).collect();
|
||||
hashes[50] = BigUint::from(2u8).pow(lsv.verifier.field_size() as u32);
|
||||
let res = lsv.receive_plaintext_hashes(hashes);
|
||||
|
||||
assert_eq!(res.err().unwrap(), VerifierError::BigUintTooLarge);
|
||||
|
||||
// test receive_labelsum_hashes
|
||||
let lsv = LabelsumVerifier {
|
||||
state: ReceiveLabelsumHashes::default(),
|
||||
verifier: Box::new(CorrectTestVerifier {}),
|
||||
};
|
||||
|
||||
let mut plaintext_hashes: Vec<BigUint> =
|
||||
(0..100).map(|i| BigUint::from(i as u64)).collect();
|
||||
plaintext_hashes[50] = BigUint::from(2u8).pow(lsv.verifier.field_size() as u32);
|
||||
let res = lsv.receive_labelsum_hashes(plaintext_hashes);
|
||||
|
||||
assert_eq!(res.err().unwrap(), VerifierError::BigUintTooLarge);
|
||||
}
|
||||
|
||||
#[test]
|
||||
/// Provide too many/too few proofs and trigger [VerifierError::WrongProofCount]
|
||||
fn test_error_wrong_proof_count() {
|
||||
// 3 chunks
|
||||
let lsv = LabelsumVerifier {
|
||||
state: VerifyMany {
|
||||
deltas: vec![0u8.into(); 3670 * 2 + 1],
|
||||
..Default::default()
|
||||
},
|
||||
verifier: Box::new(CorrectTestVerifier {}),
|
||||
};
|
||||
// 4 proofs
|
||||
let res = lsv.verify_many(vec![Proof::default(); 4]);
|
||||
|
||||
assert_eq!(res.err().unwrap(), VerifierError::WrongProofCount);
|
||||
|
||||
// 3 chunks
|
||||
let lsv = LabelsumVerifier {
|
||||
state: VerifyMany {
|
||||
deltas: vec![0u8.into(); 3670 * 2 + 1],
|
||||
..Default::default()
|
||||
},
|
||||
verifier: Box::new(CorrectTestVerifier {}),
|
||||
};
|
||||
// 2 proofs
|
||||
let res = lsv.verify_many(vec![Proof::default(); 2]);
|
||||
|
||||
assert_eq!(res.err().unwrap(), VerifierError::WrongProofCount);
|
||||
}
|
||||
|
||||
#[test]
|
||||
/// Returns `false` when attempting to verify and triggers
|
||||
/// [VerifierError::VerificationFailed]
|
||||
fn test_error_verification_failed() {
|
||||
struct TestVerifier {}
|
||||
impl Verify for TestVerifier {
|
||||
fn verify(&self, _input: VerificationInput) -> Result<bool, VerifierError> {
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
fn field_size(&self) -> usize {
|
||||
254
|
||||
}
|
||||
|
||||
fn useful_bits(&self) -> usize {
|
||||
253
|
||||
}
|
||||
|
||||
fn chunk_size(&self) -> usize {
|
||||
3670
|
||||
}
|
||||
}
|
||||
|
||||
let lsv = LabelsumVerifier {
|
||||
state: VerifyMany {
|
||||
deltas: vec![0u8.into(); 3670 * 2 - 1],
|
||||
zero_sums: vec![0u8.into(); 2],
|
||||
plaintext_hashes: vec![0u8.into(); 2],
|
||||
labelsum_hashes: vec![0u8.into(); 2],
|
||||
},
|
||||
verifier: Box::new(TestVerifier {}),
|
||||
};
|
||||
let res = lsv.verify_many(vec![Proof::default(); 2]);
|
||||
|
||||
assert_eq!(res.err().unwrap(), VerifierError::VerificationFailed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
/// Returns some other error not related to the verification result when
|
||||
/// attempting to verify and checks that the error propagates.
|
||||
fn test_verification_error() {
|
||||
struct TestVerifier {}
|
||||
impl Verify for TestVerifier {
|
||||
fn verify(&self, _input: VerificationInput) -> Result<bool, VerifierError> {
|
||||
Err(VerifierError::VerifyingBackendError)
|
||||
}
|
||||
|
||||
fn field_size(&self) -> usize {
|
||||
254
|
||||
}
|
||||
|
||||
fn useful_bits(&self) -> usize {
|
||||
253
|
||||
}
|
||||
|
||||
fn chunk_size(&self) -> usize {
|
||||
3670
|
||||
}
|
||||
}
|
||||
|
||||
let lsv = LabelsumVerifier {
|
||||
state: VerifyMany {
|
||||
deltas: vec![0u8.into(); 3670 * 2 - 1],
|
||||
zero_sums: vec![0u8.into(); 2],
|
||||
plaintext_hashes: vec![0u8.into(); 2],
|
||||
labelsum_hashes: vec![0u8.into(); 2],
|
||||
},
|
||||
verifier: Box::new(TestVerifier {}),
|
||||
};
|
||||
let res = lsv.verify_many(vec![Proof::default(); 2]);
|
||||
|
||||
assert_eq!(res.err().unwrap(), VerifierError::VerifyingBackendError);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user