Encrypt/decrypt multiple, Flatten2D into models

This commit is contained in:
Cathie So
2022-11-30 00:04:30 +08:00
parent 444c98dc1b
commit 822c271259
8 changed files with 130 additions and 43 deletions

View File

@@ -0,0 +1,5 @@
pragma circom 2.0.3;
include "../../circuits/crypto/encrypt.circom";
component main = DecryptBits(1000);

View File

@@ -0,0 +1,5 @@
pragma circom 2.0.3;
include "../../circuits/crypto/encrypt.circom";
component main = EncryptBits(1000);

View File

@@ -5,6 +5,7 @@ include "../../circuits/Dense.circom";
include "../../circuits/ArgMax.circom";
include "../../circuits/Poly.circom";
include "../../circuits/SumPooling2D.circom";
include "../../circuits/Flatten2D.circom";
template mnist_convnet() {
signal input in[28][28][1];
@@ -22,6 +23,7 @@ template mnist_convnet() {
component conv2d_2 = Conv2D(13,13,4,8,3,1);
component poly_2[11][11][8];
component sum2d_2 = SumPooling2D(11,11,8,2,2);
component flatten = Flatten2D(5,5,8);
component dense = Dense(200,10);
component argmax = ArgMax(10);
@@ -79,20 +81,21 @@ template mnist_convnet() {
}
}
var idx = 0;
for (var i=0; i<5; i++) {
for (var j=0; j<5; j++) {
for (var k=0; k<8; k++) {
dense.in[idx] <== sum2d_2.out[i][j][k];
for (var m=0; m<10; m++) {
dense.weights[idx][m] <== dense_weights[idx][m];
}
idx++;
flatten.in[i][j][k] <== sum2d_2.out[i][j][k];
}
}
}
for (var i=0; i<200; i++) {
dense.in[i] <== flatten.out[i];
for (var j=0; j<10; j++) {
dense.weights[i][j] <== dense_weights[i][j];
}
}
for (var i=0; i<10; i++) {
dense.bias[i] <== dense_bias[i];
}

View File

@@ -6,6 +6,7 @@ include "../../circuits/ArgMax.circom";
include "../../circuits/ReLU.circom";
include "../../circuits/AveragePooling2D.circom";
include "../../circuits/BatchNormalization2D.circom";
include "../../circuits/Flatten2D.circom";
template mnist_latest_optimized() {
signal input in[28][28][1];
@@ -29,6 +30,7 @@ template mnist_latest_optimized() {
component bn_2 = BatchNormalization2D(11,11,8);
component relu_2[11][11][8];
component avg2d_2 = AveragePooling2D(11,11,8,2,2,25);
component flatten = Flatten2D(5,5,8);
component dense = Dense(200,10);
component argmax = ArgMax(10);
@@ -107,20 +109,21 @@ template mnist_latest_optimized() {
}
}
var idx = 0;
for (var i=0; i<5; i++) {
for (var j=0; j<5; j++) {
for (var k=0; k<8; k++) {
dense.in[idx] <== avg2d_2.out[i][j][k];
for (var m=0; m<10; m++) {
dense.weights[idx][m] <== dense_weights[idx][m];
}
idx++;
flatten.in[i][j][k] <== avg2d_2.out[i][j][k];
}
}
}
for (var i=0; i<200; i++) {
dense.in[i] <== flatten.out[i];
for (var j=0; j<10; j++) {
dense.weights[i][j] <== dense_weights[i][j];
}
}
for (var i=0; i<10; i++) {
dense.bias[i] <== dense_bias[i];
}

View File

@@ -6,6 +6,7 @@ include "../../circuits/ArgMax.circom";
include "../../circuits/Poly.circom";
include "../../circuits/AveragePooling2D.circom";
include "../../circuits/BatchNormalization2D.circom";
include "../../circuits/Flatten2D.circom";
template mnist_latest() {
signal input in[28][28][1];
@@ -29,6 +30,7 @@ template mnist_latest() {
component bn_2 = BatchNormalization2D(11,11,8);
component poly_2[11][11][8];
component avg2d_2 = AveragePooling2D(11,11,8,2,2,25);
component flatten = Flatten2D(5,5,8);
component dense = Dense(200,10);
component argmax = ArgMax(10);
@@ -107,20 +109,21 @@ template mnist_latest() {
}
}
var idx = 0;
for (var i=0; i<5; i++) {
for (var j=0; j<5; j++) {
for (var k=0; k<8; k++) {
dense.in[idx] <== avg2d_2.out[i][j][k];
for (var m=0; m<10; m++) {
dense.weights[idx][m] <== dense_weights[idx][m];
}
idx++;
flatten.in[i][j][k] <== avg2d_2.out[i][j][k];
}
}
}
for (var i=0; i<200; i++) {
dense.in[i] <== flatten.out[i];
for (var j=0; j<10; j++) {
dense.weights[i][j] <== dense_weights[i][j];
}
}
for (var i=0; i<10; i++) {
dense.bias[i] <== dense_bias[i];
}

View File

@@ -4,6 +4,7 @@ include "../../circuits/Conv2D.circom";
include "../../circuits/Dense.circom";
include "../../circuits/ArgMax.circom";
include "../../circuits/Poly.circom";
include "../../circuits/Flatten2D.circom";
template mnist_poly() {
signal input in[28][28][1];
@@ -14,6 +15,7 @@ template mnist_poly() {
signal output out;
component conv2d = Conv2D(28,28,1,1,3,1);
component flatten = Flatten2D(26,26,1);
component poly[26*26];
component dense = Dense(676,10);
component argmax = ArgMax(10);
@@ -32,17 +34,18 @@ template mnist_poly() {
conv2d.bias[0] <== conv2d_bias[0];
var idx = 0;
for (var i=0; i<26; i++) {
for (var j=0; j<26; j++) {
poly[idx] = Poly(10**18);
poly[idx].in <== conv2d.out[i][j][0];
dense.in[idx] <== poly[idx].out;
for (var k=0; k<10; k++) {
dense.weights[idx][k] <== dense_weights[idx][k];
}
idx++;
flatten.in[i][j][0] <== conv2d.out[i][j][0];
}
}
for (var i=0; i<26*26; i++) {
poly[i] = Poly(10**18);
poly[i].in <== flatten.out[i];
dense.in[i] <== poly[i].out;
for (var j=0; j<10; j++) {
dense.weights[i][j] <== dense_weights[i][j];
}
}

View File

@@ -4,6 +4,7 @@ include "../../circuits/Conv2D.circom";
include "../../circuits/Dense.circom";
include "../../circuits/ArgMax.circom";
include "../../circuits/ReLU.circom";
include "../../circuits/Flatten2D.circom";
template mnist() {
signal input in[28][28][1];
@@ -14,6 +15,7 @@ template mnist() {
signal output out;
component conv2d = Conv2D(28,28,1,1,3,1);
component flatten = Flatten2D(26,26,1);
component relu[26*26];
component dense = Dense(676,10);
component argmax = ArgMax(10);
@@ -32,17 +34,18 @@ template mnist() {
conv2d.bias[0] <== conv2d_bias[0];
var idx = 0;
for (var i=0; i<26; i++) {
for (var j=0; j<26; j++) {
relu[idx] = ReLU();
relu[idx].in <== conv2d.out[i][j][0];
dense.in[idx] <== relu[idx].out;
for (var k=0; k<10; k++) {
dense.weights[idx][k] <== dense_weights[idx][k];
}
idx++;
flatten.in[i][j][0] <== conv2d.out[i][j][0];
}
}
for (var i=0; i<26*26; i++) {
relu[i] = ReLU();
relu[i].in <== flatten.out[i];
dense.in[i] <== relu[i].out;
for (var j=0; j<10; j++) {
dense.weights[i][j] <== dense_weights[i][j];
}
}

View File

@@ -34,7 +34,7 @@ describe("crypto circuits test", function () {
assert(Fr.eq(Fr.e(witness[2]), Fr.e(keypair.pubKey.rawPubKey[1])));
});
it("ecdh shared key test", async () => {
it("ecdh full circuit test", async () => {
const circuit = await wasm_tester(path.join(__dirname, "circuits", "ecdh_test.circom"));
const keypair = new Keypair();
@@ -142,7 +142,69 @@ describe("crypto circuits test", function () {
assert(Fr.eq(Fr.e(witness[1]), Fr.e(plaintext)));
});
// TODO: encrypt multiple
// TODO: decrypt multiple
it("encrypt multiple in circom, decrypt in js", async () => {
const circuit = await wasm_tester(path.join(__dirname, "circuits", "encryptMultiple_test.circom"));
const keypair = new Keypair();
const keypair2 = new Keypair();
const ecdhSharedKey = Keypair.genEcdhSharedKey(
keypair.privKey,
keypair2.pubKey,
);
const plaintext = [...Array(1000).keys()];
const INPUT = {
'plaintext': plaintext.map(String),
'shared_key': ecdhSharedKey.toString(),
}
const witness = await circuit.calculateWitness(INPUT, true);
const ciphertext = {
iv: witness[1],
data: witness.slice(2,1002),
}
decryptedText = decrypt(ciphertext, ecdhSharedKey);
for (let i=0; i<1000; i++) {
assert(Fr.eq(Fr.e(decryptedText[i]), Fr.e(plaintext[i])));
}
});
it("encrypt multiple in js, decrypt in circom", async () => {
const circuit = await wasm_tester(path.join(__dirname, "circuits", "decryptMultiple_test.circom"));
const keypair = new Keypair();
const keypair2 = new Keypair();
const ecdhSharedKey = Keypair.genEcdhSharedKey(
keypair.privKey,
keypair2.pubKey,
);
const plaintext = ([...Array(1000).keys()]).map(BigInt);
const ciphertext = encrypt(plaintext, ecdhSharedKey);
const INPUT = {
'message': [ciphertext.iv.toString(), ...ciphertext.data.map(String)],
'shared_key': ecdhSharedKey.toString(),
}
const witness = await circuit.calculateWitness(INPUT, true);
for (let i=0; i<1000; i++) {
assert(Fr.eq(Fr.e(witness[i+1]), Fr.e(plaintext[i])));
}
});
// TODO: encrypt a model
it("encrypt entire model in circom, decrypt in js", async () => {
});
});