From 822c271259f2d63182921f7ecf47c98e2436b737 Mon Sep 17 00:00:00 2001 From: Cathie So Date: Wed, 30 Nov 2022 00:04:30 +0800 Subject: [PATCH] Encrypt/decrypt multiple, Flatten2D into models --- test/circuits/decryptMultiple_test.circom | 5 ++ test/circuits/encryptMultiple_test.circom | 5 ++ test/circuits/mnist_convnet_test.circom | 17 +++-- .../mnist_latest_optimized_test.circom | 17 +++-- test/circuits/mnist_latest_test.circom | 17 +++-- test/circuits/mnist_poly_test.circom | 21 +++--- test/circuits/mnist_test.circom | 21 +++--- test/encryption.js | 70 +++++++++++++++++-- 8 files changed, 130 insertions(+), 43 deletions(-) create mode 100644 test/circuits/decryptMultiple_test.circom create mode 100644 test/circuits/encryptMultiple_test.circom diff --git a/test/circuits/decryptMultiple_test.circom b/test/circuits/decryptMultiple_test.circom new file mode 100644 index 0000000..b81ae07 --- /dev/null +++ b/test/circuits/decryptMultiple_test.circom @@ -0,0 +1,5 @@ +pragma circom 2.0.3; + +include "../../circuits/crypto/encrypt.circom"; + +component main = DecryptBits(1000); \ No newline at end of file diff --git a/test/circuits/encryptMultiple_test.circom b/test/circuits/encryptMultiple_test.circom new file mode 100644 index 0000000..6e1edeb --- /dev/null +++ b/test/circuits/encryptMultiple_test.circom @@ -0,0 +1,5 @@ +pragma circom 2.0.3; + +include "../../circuits/crypto/encrypt.circom"; + +component main = EncryptBits(1000); \ No newline at end of file diff --git a/test/circuits/mnist_convnet_test.circom b/test/circuits/mnist_convnet_test.circom index 3770140..1f532e4 100644 --- a/test/circuits/mnist_convnet_test.circom +++ b/test/circuits/mnist_convnet_test.circom @@ -5,6 +5,7 @@ include "../../circuits/Dense.circom"; include "../../circuits/ArgMax.circom"; include "../../circuits/Poly.circom"; include "../../circuits/SumPooling2D.circom"; +include "../../circuits/Flatten2D.circom"; template mnist_convnet() { signal input in[28][28][1]; @@ -22,6 +23,7 @@ template mnist_convnet() { component conv2d_2 = Conv2D(13,13,4,8,3,1); component poly_2[11][11][8]; component sum2d_2 = SumPooling2D(11,11,8,2,2); + component flatten = Flatten2D(5,5,8); component dense = Dense(200,10); component argmax = ArgMax(10); @@ -79,20 +81,21 @@ template mnist_convnet() { } } - var idx = 0; - for (var i=0; i<5; i++) { for (var j=0; j<5; j++) { for (var k=0; k<8; k++) { - dense.in[idx] <== sum2d_2.out[i][j][k]; - for (var m=0; m<10; m++) { - dense.weights[idx][m] <== dense_weights[idx][m]; - } - idx++; + flatten.in[i][j][k] <== sum2d_2.out[i][j][k]; } } } + for (var i=0; i<200; i++) { + dense.in[i] <== flatten.out[i]; + for (var j=0; j<10; j++) { + dense.weights[i][j] <== dense_weights[i][j]; + } + } + for (var i=0; i<10; i++) { dense.bias[i] <== dense_bias[i]; } diff --git a/test/circuits/mnist_latest_optimized_test.circom b/test/circuits/mnist_latest_optimized_test.circom index 6002223..efa2b74 100644 --- a/test/circuits/mnist_latest_optimized_test.circom +++ b/test/circuits/mnist_latest_optimized_test.circom @@ -6,6 +6,7 @@ include "../../circuits/ArgMax.circom"; include "../../circuits/ReLU.circom"; include "../../circuits/AveragePooling2D.circom"; include "../../circuits/BatchNormalization2D.circom"; +include "../../circuits/Flatten2D.circom"; template mnist_latest_optimized() { signal input in[28][28][1]; @@ -29,6 +30,7 @@ template mnist_latest_optimized() { component bn_2 = BatchNormalization2D(11,11,8); component relu_2[11][11][8]; component avg2d_2 = AveragePooling2D(11,11,8,2,2,25); + component flatten = Flatten2D(5,5,8); component dense = Dense(200,10); component argmax = ArgMax(10); @@ -107,20 +109,21 @@ template mnist_latest_optimized() { } } - var idx = 0; - for (var i=0; i<5; i++) { for (var j=0; j<5; j++) { for (var k=0; k<8; k++) { - dense.in[idx] <== avg2d_2.out[i][j][k]; - for (var m=0; m<10; m++) { - dense.weights[idx][m] <== dense_weights[idx][m]; - } - idx++; + flatten.in[i][j][k] <== avg2d_2.out[i][j][k]; } } } + for (var i=0; i<200; i++) { + dense.in[i] <== flatten.out[i]; + for (var j=0; j<10; j++) { + dense.weights[i][j] <== dense_weights[i][j]; + } + } + for (var i=0; i<10; i++) { dense.bias[i] <== dense_bias[i]; } diff --git a/test/circuits/mnist_latest_test.circom b/test/circuits/mnist_latest_test.circom index b7dae35..51a39dd 100644 --- a/test/circuits/mnist_latest_test.circom +++ b/test/circuits/mnist_latest_test.circom @@ -6,6 +6,7 @@ include "../../circuits/ArgMax.circom"; include "../../circuits/Poly.circom"; include "../../circuits/AveragePooling2D.circom"; include "../../circuits/BatchNormalization2D.circom"; +include "../../circuits/Flatten2D.circom"; template mnist_latest() { signal input in[28][28][1]; @@ -29,6 +30,7 @@ template mnist_latest() { component bn_2 = BatchNormalization2D(11,11,8); component poly_2[11][11][8]; component avg2d_2 = AveragePooling2D(11,11,8,2,2,25); + component flatten = Flatten2D(5,5,8); component dense = Dense(200,10); component argmax = ArgMax(10); @@ -107,20 +109,21 @@ template mnist_latest() { } } - var idx = 0; - for (var i=0; i<5; i++) { for (var j=0; j<5; j++) { for (var k=0; k<8; k++) { - dense.in[idx] <== avg2d_2.out[i][j][k]; - for (var m=0; m<10; m++) { - dense.weights[idx][m] <== dense_weights[idx][m]; - } - idx++; + flatten.in[i][j][k] <== avg2d_2.out[i][j][k]; } } } + for (var i=0; i<200; i++) { + dense.in[i] <== flatten.out[i]; + for (var j=0; j<10; j++) { + dense.weights[i][j] <== dense_weights[i][j]; + } + } + for (var i=0; i<10; i++) { dense.bias[i] <== dense_bias[i]; } diff --git a/test/circuits/mnist_poly_test.circom b/test/circuits/mnist_poly_test.circom index a998fca..577c8d9 100644 --- a/test/circuits/mnist_poly_test.circom +++ b/test/circuits/mnist_poly_test.circom @@ -4,6 +4,7 @@ include "../../circuits/Conv2D.circom"; include "../../circuits/Dense.circom"; include "../../circuits/ArgMax.circom"; include "../../circuits/Poly.circom"; +include "../../circuits/Flatten2D.circom"; template mnist_poly() { signal input in[28][28][1]; @@ -14,6 +15,7 @@ template mnist_poly() { signal output out; component conv2d = Conv2D(28,28,1,1,3,1); + component flatten = Flatten2D(26,26,1); component poly[26*26]; component dense = Dense(676,10); component argmax = ArgMax(10); @@ -32,17 +34,18 @@ template mnist_poly() { conv2d.bias[0] <== conv2d_bias[0]; - var idx = 0; - for (var i=0; i<26; i++) { for (var j=0; j<26; j++) { - poly[idx] = Poly(10**18); - poly[idx].in <== conv2d.out[i][j][0]; - dense.in[idx] <== poly[idx].out; - for (var k=0; k<10; k++) { - dense.weights[idx][k] <== dense_weights[idx][k]; - } - idx++; + flatten.in[i][j][0] <== conv2d.out[i][j][0]; + } + } + + for (var i=0; i<26*26; i++) { + poly[i] = Poly(10**18); + poly[i].in <== flatten.out[i]; + dense.in[i] <== poly[i].out; + for (var j=0; j<10; j++) { + dense.weights[i][j] <== dense_weights[i][j]; } } diff --git a/test/circuits/mnist_test.circom b/test/circuits/mnist_test.circom index d2ff66a..2a72dc5 100644 --- a/test/circuits/mnist_test.circom +++ b/test/circuits/mnist_test.circom @@ -4,6 +4,7 @@ include "../../circuits/Conv2D.circom"; include "../../circuits/Dense.circom"; include "../../circuits/ArgMax.circom"; include "../../circuits/ReLU.circom"; +include "../../circuits/Flatten2D.circom"; template mnist() { signal input in[28][28][1]; @@ -14,6 +15,7 @@ template mnist() { signal output out; component conv2d = Conv2D(28,28,1,1,3,1); + component flatten = Flatten2D(26,26,1); component relu[26*26]; component dense = Dense(676,10); component argmax = ArgMax(10); @@ -32,17 +34,18 @@ template mnist() { conv2d.bias[0] <== conv2d_bias[0]; - var idx = 0; - for (var i=0; i<26; i++) { for (var j=0; j<26; j++) { - relu[idx] = ReLU(); - relu[idx].in <== conv2d.out[i][j][0]; - dense.in[idx] <== relu[idx].out; - for (var k=0; k<10; k++) { - dense.weights[idx][k] <== dense_weights[idx][k]; - } - idx++; + flatten.in[i][j][0] <== conv2d.out[i][j][0]; + } + } + + for (var i=0; i<26*26; i++) { + relu[i] = ReLU(); + relu[i].in <== flatten.out[i]; + dense.in[i] <== relu[i].out; + for (var j=0; j<10; j++) { + dense.weights[i][j] <== dense_weights[i][j]; } } diff --git a/test/encryption.js b/test/encryption.js index 63d8ce8..04584fa 100644 --- a/test/encryption.js +++ b/test/encryption.js @@ -34,7 +34,7 @@ describe("crypto circuits test", function () { assert(Fr.eq(Fr.e(witness[2]), Fr.e(keypair.pubKey.rawPubKey[1]))); }); - it("ecdh shared key test", async () => { + it("ecdh full circuit test", async () => { const circuit = await wasm_tester(path.join(__dirname, "circuits", "ecdh_test.circom")); const keypair = new Keypair(); @@ -142,7 +142,69 @@ describe("crypto circuits test", function () { assert(Fr.eq(Fr.e(witness[1]), Fr.e(plaintext))); }); - - // TODO: encrypt multiple - // TODO: decrypt multiple + + it("encrypt multiple in circom, decrypt in js", async () => { + const circuit = await wasm_tester(path.join(__dirname, "circuits", "encryptMultiple_test.circom")); + + const keypair = new Keypair(); + const keypair2 = new Keypair(); + + const ecdhSharedKey = Keypair.genEcdhSharedKey( + keypair.privKey, + keypair2.pubKey, + ); + + const plaintext = [...Array(1000).keys()]; + + const INPUT = { + 'plaintext': plaintext.map(String), + 'shared_key': ecdhSharedKey.toString(), + } + + const witness = await circuit.calculateWitness(INPUT, true); + + const ciphertext = { + iv: witness[1], + data: witness.slice(2,1002), + } + + decryptedText = decrypt(ciphertext, ecdhSharedKey); + + for (let i=0; i<1000; i++) { + assert(Fr.eq(Fr.e(decryptedText[i]), Fr.e(plaintext[i]))); + } + + }); + + it("encrypt multiple in js, decrypt in circom", async () => { + const circuit = await wasm_tester(path.join(__dirname, "circuits", "decryptMultiple_test.circom")); + + const keypair = new Keypair(); + const keypair2 = new Keypair(); + + const ecdhSharedKey = Keypair.genEcdhSharedKey( + keypair.privKey, + keypair2.pubKey, + ); + + const plaintext = ([...Array(1000).keys()]).map(BigInt); + + const ciphertext = encrypt(plaintext, ecdhSharedKey); + + const INPUT = { + 'message': [ciphertext.iv.toString(), ...ciphertext.data.map(String)], + 'shared_key': ecdhSharedKey.toString(), + } + + const witness = await circuit.calculateWitness(INPUT, true); + + for (let i=0; i<1000; i++) { + assert(Fr.eq(Fr.e(witness[i+1]), Fr.e(plaintext[i]))); + } + + }); + + // TODO: encrypt a model + it("encrypt entire model in circom, decrypt in js", async () => { + }); }); \ No newline at end of file