mirror of
https://github.com/socathie/circomlib-ml.git
synced 2026-01-09 14:08:04 -05:00
model encrypt test, "optimized" --> "precision"
This commit is contained in:
180
test/circuits/encrypted_mnist_latest_test.circom
Normal file
180
test/circuits/encrypted_mnist_latest_test.circom
Normal file
@@ -0,0 +1,180 @@
|
||||
pragma circom 2.0.3;
|
||||
|
||||
include "../../circuits/Conv2D.circom";
|
||||
include "../../circuits/Dense.circom";
|
||||
include "../../circuits/ArgMax.circom";
|
||||
include "../../circuits/Poly.circom";
|
||||
include "../../circuits/AveragePooling2D.circom";
|
||||
include "../../circuits/BatchNormalization2D.circom";
|
||||
include "../../circuits/Flatten2D.circom";
|
||||
include "../../circuits/crypto/encrypt.circom";
|
||||
|
||||
template encrypted_mnist_latest() {
|
||||
signal input shared_key;
|
||||
signal input in[28][28][1];
|
||||
signal input conv2d_1_weights[3][3][1][4];
|
||||
signal input conv2d_1_bias[4];
|
||||
signal input bn_1_a[4];
|
||||
signal input bn_1_b[4];
|
||||
signal input conv2d_2_weights[3][3][4][8];
|
||||
signal input conv2d_2_bias[8];
|
||||
signal input bn_2_a[8];
|
||||
signal input bn_2_b[8];
|
||||
signal input dense_weights[200][10];
|
||||
signal input dense_bias[10];
|
||||
signal output out;
|
||||
signal output message[3*3*1*4+4+4+4+3*3*4*8+8+8+8+200*10+10+1];
|
||||
|
||||
component enc = EncryptBits(3*3*1*4+4+4+4+3*3*4*8+8+8+8+200*10+10);
|
||||
enc.shared_key <== shared_key;
|
||||
var idx = 0;
|
||||
|
||||
component conv2d_1 = Conv2D(28,28,1,4,3,1);
|
||||
component bn_1 = BatchNormalization2D(26,26,4);
|
||||
component poly_1[26][26][4];
|
||||
component avg2d_1 = AveragePooling2D(26,26,4,2,2,25);
|
||||
component conv2d_2 = Conv2D(13,13,4,8,3,1);
|
||||
component bn_2 = BatchNormalization2D(11,11,8);
|
||||
component poly_2[11][11][8];
|
||||
component avg2d_2 = AveragePooling2D(11,11,8,2,2,25);
|
||||
component flatten = Flatten2D(5,5,8);
|
||||
component dense = Dense(200,10);
|
||||
component argmax = ArgMax(10);
|
||||
|
||||
for (var i=0; i<28; i++) {
|
||||
for (var j=0; j<28; j++) {
|
||||
conv2d_1.in[i][j][0] <== in[i][j][0];
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<3; i++) {
|
||||
for (var j=0; j<3; j++) {
|
||||
for (var m=0; m<4; m++) {
|
||||
conv2d_1.weights[i][j][0][m] <== conv2d_1_weights[i][j][0][m];
|
||||
enc.plaintext[idx] <== conv2d_1_weights[i][j][0][m];
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (var m=0; m<4; m++) {
|
||||
conv2d_1.bias[m] <== conv2d_1_bias[m];
|
||||
enc.plaintext[idx] <== conv2d_1_bias[m];
|
||||
idx++;
|
||||
}
|
||||
|
||||
for (var k=0; k<4; k++) {
|
||||
bn_1.a[k] <== bn_1_a[k];
|
||||
enc.plaintext[idx] <== bn_1_a[k];
|
||||
idx++;
|
||||
}
|
||||
|
||||
for (var k=0; k<4; k++) {
|
||||
bn_1.b[k] <== bn_1_b[k];
|
||||
enc.plaintext[idx] <== bn_1_b[k];
|
||||
idx++;
|
||||
for (var i=0; i<26; i++) {
|
||||
for (var j=0; j<26; j++) {
|
||||
bn_1.in[i][j][k] <== conv2d_1.out[i][j][k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<26; i++) {
|
||||
for (var j=0; j<26; j++) {
|
||||
for (var k=0; k<4; k++) {
|
||||
poly_1[i][j][k] = Poly(10**6);
|
||||
poly_1[i][j][k].in <== bn_1.out[i][j][k];
|
||||
avg2d_1.in[i][j][k] <== poly_1[i][j][k].out;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<13; i++) {
|
||||
for (var j=0; j<13; j++) {
|
||||
for (var k=0; k<4; k++) {
|
||||
conv2d_2.in[i][j][k] <== avg2d_1.out[i][j][k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<3; i++) {
|
||||
for (var j=0; j<3; j++) {
|
||||
for (var k=0; k<4; k++) {
|
||||
for (var m=0; m<8; m++) {
|
||||
conv2d_2.weights[i][j][k][m] <== conv2d_2_weights[i][j][k][m];
|
||||
enc.plaintext[idx] <== conv2d_2_weights[i][j][k][m];
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (var m=0; m<8; m++) {
|
||||
conv2d_2.bias[m] <== conv2d_2_bias[m];
|
||||
enc.plaintext[idx] <== conv2d_2_bias[m];
|
||||
idx++;
|
||||
}
|
||||
|
||||
for (var k=0; k<8; k++) {
|
||||
bn_2.a[k] <== bn_2_a[k];
|
||||
enc.plaintext[idx] <== bn_2_a[k];
|
||||
idx++;
|
||||
}
|
||||
|
||||
for (var k=0; k<8; k++) {
|
||||
bn_2.b[k] <== bn_2_b[k];
|
||||
enc.plaintext[idx] <== bn_2_b[k];
|
||||
idx++;
|
||||
for (var i=0; i<11; i++) {
|
||||
for (var j=0; j<11; j++) {
|
||||
bn_2.in[i][j][k] <== conv2d_2.out[i][j][k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<11; i++) {
|
||||
for (var j=0; j<11; j++) {
|
||||
for (var k=0; k<8; k++) {
|
||||
poly_2[i][j][k] = Poly(10**18);
|
||||
poly_2[i][j][k].in <== bn_2.out[i][j][k];
|
||||
avg2d_2.in[i][j][k] <== poly_2[i][j][k].out;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<5; i++) {
|
||||
for (var j=0; j<5; j++) {
|
||||
for (var k=0; k<8; k++) {
|
||||
flatten.in[i][j][k] <== avg2d_2.out[i][j][k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<200; i++) {
|
||||
dense.in[i] <== flatten.out[i];
|
||||
for (var j=0; j<10; j++) {
|
||||
dense.weights[i][j] <== dense_weights[i][j];
|
||||
enc.plaintext[idx] <== dense_weights[i][j];
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<10; i++) {
|
||||
dense.bias[i] <== dense_bias[i];
|
||||
enc.plaintext[idx] <== dense_bias[i];
|
||||
idx++;
|
||||
}
|
||||
|
||||
for (var i=0; i<10; i++) {
|
||||
argmax.in[i] <== dense.out[i];
|
||||
}
|
||||
|
||||
out <== argmax.out;
|
||||
|
||||
for (var i=0; i<3*3*1*4+4+4+4+3*3*4*8+8+8+8+200*10+10+1; i++) {
|
||||
message[i] <== enc.out[i];
|
||||
}
|
||||
}
|
||||
|
||||
component main = encrypted_mnist_latest();
|
||||
@@ -8,7 +8,7 @@ include "../../circuits/AveragePooling2D.circom";
|
||||
include "../../circuits/BatchNormalization2D.circom";
|
||||
include "../../circuits/Flatten2D.circom";
|
||||
|
||||
template mnist_latest_optimized() {
|
||||
template mnist_latest_precision() {
|
||||
signal input in[28][28][1];
|
||||
signal input conv2d_1_weights[3][3][1][4];
|
||||
signal input conv2d_1_bias[4];
|
||||
@@ -136,4 +136,4 @@ template mnist_latest_optimized() {
|
||||
out <== argmax.out;
|
||||
}
|
||||
|
||||
component main = mnist_latest_optimized();
|
||||
component main = mnist_latest_precision();
|
||||
@@ -206,5 +206,60 @@ describe("crypto circuits test", function () {
|
||||
|
||||
// TODO: encrypt a model
|
||||
it("encrypt entire model in circom, decrypt in js", async () => {
|
||||
const circuit = await wasm_tester(path.join(__dirname, "circuits", "encrypted_mnist_latest_test.circom"));
|
||||
const json = require("../models/mnist_latest_input.json");
|
||||
|
||||
let INPUT = {};
|
||||
let plaintext = [
|
||||
...json['conv2d_1_weights'],
|
||||
...json['conv2d_1_bias'],
|
||||
...json['bn_1_a'],
|
||||
...json['bn_1_b'],
|
||||
...json['conv2d_2_weights'],
|
||||
...json['conv2d_2_bias'],
|
||||
...json['bn_2_a'],
|
||||
...json['bn_2_b'],
|
||||
...json['dense_weights'],
|
||||
...json['dense_bias'],
|
||||
];
|
||||
|
||||
for (const [key, value] of Object.entries(json)) {
|
||||
if (Array.isArray(value)) {
|
||||
let tmpArray = [];
|
||||
for (let i = 0; i < value.flat().length; i++) {
|
||||
tmpArray.push(Fr.e(value.flat()[i]));
|
||||
}
|
||||
INPUT[key] = tmpArray;
|
||||
} else {
|
||||
INPUT[key] = Fr.e(value);
|
||||
}
|
||||
}
|
||||
|
||||
const keypair = new Keypair();
|
||||
const keypair2 = new Keypair();
|
||||
|
||||
const ecdhSharedKey = Keypair.genEcdhSharedKey(
|
||||
keypair.privKey,
|
||||
keypair2.pubKey,
|
||||
);
|
||||
|
||||
INPUT['shared_key'] = ecdhSharedKey.toString();
|
||||
|
||||
const witness = await circuit.calculateWitness(INPUT, true);
|
||||
|
||||
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
|
||||
assert(Fr.eq(Fr.e(witness[1]),Fr.e(7)));
|
||||
|
||||
const ciphertext = {
|
||||
iv: witness[2],
|
||||
data: witness.slice(3,2373),
|
||||
}
|
||||
|
||||
decryptedText = decrypt(ciphertext, ecdhSharedKey);
|
||||
|
||||
for (let i=0; i<2370; i++) {
|
||||
assert(Fr.eq(Fr.e(decryptedText[i]), Fr.e(plaintext[i])));
|
||||
}
|
||||
|
||||
});
|
||||
});
|
||||
@@ -10,13 +10,13 @@ const Fr = new F1Field(exports.p);
|
||||
|
||||
const assert = chai.assert;
|
||||
|
||||
const json = require("../models/mnist_latest_optimized_input.json");
|
||||
const json = require("../models/mnist_latest_precision_input.json");
|
||||
|
||||
describe("mnist latest optimized test", function () {
|
||||
this.timeout(100000000);
|
||||
|
||||
it("should return correct output", async () => {
|
||||
const circuit = await wasm_tester(path.join(__dirname, "circuits", "mnist_latest_optimized_test.circom"));
|
||||
const circuit = await wasm_tester(path.join(__dirname, "circuits", "mnist_latest_precision_test.circom"));
|
||||
await circuit.loadConstraints();
|
||||
console.log(circuit.nVars, circuit.constraints.length);
|
||||
|
||||
Reference in New Issue
Block a user