model encrypt test, "optimized" --> "precision"

This commit is contained in:
Cathie So
2022-12-03 15:37:53 +08:00
parent 822c271259
commit 184c90599f
6 changed files with 240 additions and 5 deletions

View File

@@ -0,0 +1,180 @@
pragma circom 2.0.3;
include "../../circuits/Conv2D.circom";
include "../../circuits/Dense.circom";
include "../../circuits/ArgMax.circom";
include "../../circuits/Poly.circom";
include "../../circuits/AveragePooling2D.circom";
include "../../circuits/BatchNormalization2D.circom";
include "../../circuits/Flatten2D.circom";
include "../../circuits/crypto/encrypt.circom";
template encrypted_mnist_latest() {
signal input shared_key;
signal input in[28][28][1];
signal input conv2d_1_weights[3][3][1][4];
signal input conv2d_1_bias[4];
signal input bn_1_a[4];
signal input bn_1_b[4];
signal input conv2d_2_weights[3][3][4][8];
signal input conv2d_2_bias[8];
signal input bn_2_a[8];
signal input bn_2_b[8];
signal input dense_weights[200][10];
signal input dense_bias[10];
signal output out;
signal output message[3*3*1*4+4+4+4+3*3*4*8+8+8+8+200*10+10+1];
component enc = EncryptBits(3*3*1*4+4+4+4+3*3*4*8+8+8+8+200*10+10);
enc.shared_key <== shared_key;
var idx = 0;
component conv2d_1 = Conv2D(28,28,1,4,3,1);
component bn_1 = BatchNormalization2D(26,26,4);
component poly_1[26][26][4];
component avg2d_1 = AveragePooling2D(26,26,4,2,2,25);
component conv2d_2 = Conv2D(13,13,4,8,3,1);
component bn_2 = BatchNormalization2D(11,11,8);
component poly_2[11][11][8];
component avg2d_2 = AveragePooling2D(11,11,8,2,2,25);
component flatten = Flatten2D(5,5,8);
component dense = Dense(200,10);
component argmax = ArgMax(10);
for (var i=0; i<28; i++) {
for (var j=0; j<28; j++) {
conv2d_1.in[i][j][0] <== in[i][j][0];
}
}
for (var i=0; i<3; i++) {
for (var j=0; j<3; j++) {
for (var m=0; m<4; m++) {
conv2d_1.weights[i][j][0][m] <== conv2d_1_weights[i][j][0][m];
enc.plaintext[idx] <== conv2d_1_weights[i][j][0][m];
idx++;
}
}
}
for (var m=0; m<4; m++) {
conv2d_1.bias[m] <== conv2d_1_bias[m];
enc.plaintext[idx] <== conv2d_1_bias[m];
idx++;
}
for (var k=0; k<4; k++) {
bn_1.a[k] <== bn_1_a[k];
enc.plaintext[idx] <== bn_1_a[k];
idx++;
}
for (var k=0; k<4; k++) {
bn_1.b[k] <== bn_1_b[k];
enc.plaintext[idx] <== bn_1_b[k];
idx++;
for (var i=0; i<26; i++) {
for (var j=0; j<26; j++) {
bn_1.in[i][j][k] <== conv2d_1.out[i][j][k];
}
}
}
for (var i=0; i<26; i++) {
for (var j=0; j<26; j++) {
for (var k=0; k<4; k++) {
poly_1[i][j][k] = Poly(10**6);
poly_1[i][j][k].in <== bn_1.out[i][j][k];
avg2d_1.in[i][j][k] <== poly_1[i][j][k].out;
}
}
}
for (var i=0; i<13; i++) {
for (var j=0; j<13; j++) {
for (var k=0; k<4; k++) {
conv2d_2.in[i][j][k] <== avg2d_1.out[i][j][k];
}
}
}
for (var i=0; i<3; i++) {
for (var j=0; j<3; j++) {
for (var k=0; k<4; k++) {
for (var m=0; m<8; m++) {
conv2d_2.weights[i][j][k][m] <== conv2d_2_weights[i][j][k][m];
enc.plaintext[idx] <== conv2d_2_weights[i][j][k][m];
idx++;
}
}
}
}
for (var m=0; m<8; m++) {
conv2d_2.bias[m] <== conv2d_2_bias[m];
enc.plaintext[idx] <== conv2d_2_bias[m];
idx++;
}
for (var k=0; k<8; k++) {
bn_2.a[k] <== bn_2_a[k];
enc.plaintext[idx] <== bn_2_a[k];
idx++;
}
for (var k=0; k<8; k++) {
bn_2.b[k] <== bn_2_b[k];
enc.plaintext[idx] <== bn_2_b[k];
idx++;
for (var i=0; i<11; i++) {
for (var j=0; j<11; j++) {
bn_2.in[i][j][k] <== conv2d_2.out[i][j][k];
}
}
}
for (var i=0; i<11; i++) {
for (var j=0; j<11; j++) {
for (var k=0; k<8; k++) {
poly_2[i][j][k] = Poly(10**18);
poly_2[i][j][k].in <== bn_2.out[i][j][k];
avg2d_2.in[i][j][k] <== poly_2[i][j][k].out;
}
}
}
for (var i=0; i<5; i++) {
for (var j=0; j<5; j++) {
for (var k=0; k<8; k++) {
flatten.in[i][j][k] <== avg2d_2.out[i][j][k];
}
}
}
for (var i=0; i<200; i++) {
dense.in[i] <== flatten.out[i];
for (var j=0; j<10; j++) {
dense.weights[i][j] <== dense_weights[i][j];
enc.plaintext[idx] <== dense_weights[i][j];
idx++;
}
}
for (var i=0; i<10; i++) {
dense.bias[i] <== dense_bias[i];
enc.plaintext[idx] <== dense_bias[i];
idx++;
}
for (var i=0; i<10; i++) {
argmax.in[i] <== dense.out[i];
}
out <== argmax.out;
for (var i=0; i<3*3*1*4+4+4+4+3*3*4*8+8+8+8+200*10+10+1; i++) {
message[i] <== enc.out[i];
}
}
component main = encrypted_mnist_latest();

View File

@@ -8,7 +8,7 @@ include "../../circuits/AveragePooling2D.circom";
include "../../circuits/BatchNormalization2D.circom";
include "../../circuits/Flatten2D.circom";
template mnist_latest_optimized() {
template mnist_latest_precision() {
signal input in[28][28][1];
signal input conv2d_1_weights[3][3][1][4];
signal input conv2d_1_bias[4];
@@ -136,4 +136,4 @@ template mnist_latest_optimized() {
out <== argmax.out;
}
component main = mnist_latest_optimized();
component main = mnist_latest_precision();

View File

@@ -206,5 +206,60 @@ describe("crypto circuits test", function () {
// TODO: encrypt a model
it("encrypt entire model in circom, decrypt in js", async () => {
const circuit = await wasm_tester(path.join(__dirname, "circuits", "encrypted_mnist_latest_test.circom"));
const json = require("../models/mnist_latest_input.json");
let INPUT = {};
let plaintext = [
...json['conv2d_1_weights'],
...json['conv2d_1_bias'],
...json['bn_1_a'],
...json['bn_1_b'],
...json['conv2d_2_weights'],
...json['conv2d_2_bias'],
...json['bn_2_a'],
...json['bn_2_b'],
...json['dense_weights'],
...json['dense_bias'],
];
for (const [key, value] of Object.entries(json)) {
if (Array.isArray(value)) {
let tmpArray = [];
for (let i = 0; i < value.flat().length; i++) {
tmpArray.push(Fr.e(value.flat()[i]));
}
INPUT[key] = tmpArray;
} else {
INPUT[key] = Fr.e(value);
}
}
const keypair = new Keypair();
const keypair2 = new Keypair();
const ecdhSharedKey = Keypair.genEcdhSharedKey(
keypair.privKey,
keypair2.pubKey,
);
INPUT['shared_key'] = ecdhSharedKey.toString();
const witness = await circuit.calculateWitness(INPUT, true);
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
assert(Fr.eq(Fr.e(witness[1]),Fr.e(7)));
const ciphertext = {
iv: witness[2],
data: witness.slice(3,2373),
}
decryptedText = decrypt(ciphertext, ecdhSharedKey);
for (let i=0; i<2370; i++) {
assert(Fr.eq(Fr.e(decryptedText[i]), Fr.e(plaintext[i])));
}
});
});

View File

@@ -10,13 +10,13 @@ const Fr = new F1Field(exports.p);
const assert = chai.assert;
const json = require("../models/mnist_latest_optimized_input.json");
const json = require("../models/mnist_latest_precision_input.json");
describe("mnist latest optimized test", function () {
this.timeout(100000000);
it("should return correct output", async () => {
const circuit = await wasm_tester(path.join(__dirname, "circuits", "mnist_latest_optimized_test.circom"));
const circuit = await wasm_tester(path.join(__dirname, "circuits", "mnist_latest_precision_test.circom"));
await circuit.loadConstraints();
console.log(circuit.nVars, circuit.constraints.length);