mirror of
https://github.com/socathie/circomlib-ml.git
synced 2026-01-08 05:34:01 -05:00
Encrypt/decrypt multiple, Flatten2D into models
This commit is contained in:
5
test/circuits/decryptMultiple_test.circom
Normal file
5
test/circuits/decryptMultiple_test.circom
Normal file
@@ -0,0 +1,5 @@
|
||||
pragma circom 2.0.3;
|
||||
|
||||
include "../../circuits/crypto/encrypt.circom";
|
||||
|
||||
component main = DecryptBits(1000);
|
||||
5
test/circuits/encryptMultiple_test.circom
Normal file
5
test/circuits/encryptMultiple_test.circom
Normal file
@@ -0,0 +1,5 @@
|
||||
pragma circom 2.0.3;
|
||||
|
||||
include "../../circuits/crypto/encrypt.circom";
|
||||
|
||||
component main = EncryptBits(1000);
|
||||
@@ -5,6 +5,7 @@ include "../../circuits/Dense.circom";
|
||||
include "../../circuits/ArgMax.circom";
|
||||
include "../../circuits/Poly.circom";
|
||||
include "../../circuits/SumPooling2D.circom";
|
||||
include "../../circuits/Flatten2D.circom";
|
||||
|
||||
template mnist_convnet() {
|
||||
signal input in[28][28][1];
|
||||
@@ -22,6 +23,7 @@ template mnist_convnet() {
|
||||
component conv2d_2 = Conv2D(13,13,4,8,3,1);
|
||||
component poly_2[11][11][8];
|
||||
component sum2d_2 = SumPooling2D(11,11,8,2,2);
|
||||
component flatten = Flatten2D(5,5,8);
|
||||
component dense = Dense(200,10);
|
||||
component argmax = ArgMax(10);
|
||||
|
||||
@@ -79,20 +81,21 @@ template mnist_convnet() {
|
||||
}
|
||||
}
|
||||
|
||||
var idx = 0;
|
||||
|
||||
for (var i=0; i<5; i++) {
|
||||
for (var j=0; j<5; j++) {
|
||||
for (var k=0; k<8; k++) {
|
||||
dense.in[idx] <== sum2d_2.out[i][j][k];
|
||||
for (var m=0; m<10; m++) {
|
||||
dense.weights[idx][m] <== dense_weights[idx][m];
|
||||
}
|
||||
idx++;
|
||||
flatten.in[i][j][k] <== sum2d_2.out[i][j][k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<200; i++) {
|
||||
dense.in[i] <== flatten.out[i];
|
||||
for (var j=0; j<10; j++) {
|
||||
dense.weights[i][j] <== dense_weights[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<10; i++) {
|
||||
dense.bias[i] <== dense_bias[i];
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ include "../../circuits/ArgMax.circom";
|
||||
include "../../circuits/ReLU.circom";
|
||||
include "../../circuits/AveragePooling2D.circom";
|
||||
include "../../circuits/BatchNormalization2D.circom";
|
||||
include "../../circuits/Flatten2D.circom";
|
||||
|
||||
template mnist_latest_optimized() {
|
||||
signal input in[28][28][1];
|
||||
@@ -29,6 +30,7 @@ template mnist_latest_optimized() {
|
||||
component bn_2 = BatchNormalization2D(11,11,8);
|
||||
component relu_2[11][11][8];
|
||||
component avg2d_2 = AveragePooling2D(11,11,8,2,2,25);
|
||||
component flatten = Flatten2D(5,5,8);
|
||||
component dense = Dense(200,10);
|
||||
component argmax = ArgMax(10);
|
||||
|
||||
@@ -107,20 +109,21 @@ template mnist_latest_optimized() {
|
||||
}
|
||||
}
|
||||
|
||||
var idx = 0;
|
||||
|
||||
for (var i=0; i<5; i++) {
|
||||
for (var j=0; j<5; j++) {
|
||||
for (var k=0; k<8; k++) {
|
||||
dense.in[idx] <== avg2d_2.out[i][j][k];
|
||||
for (var m=0; m<10; m++) {
|
||||
dense.weights[idx][m] <== dense_weights[idx][m];
|
||||
}
|
||||
idx++;
|
||||
flatten.in[i][j][k] <== avg2d_2.out[i][j][k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<200; i++) {
|
||||
dense.in[i] <== flatten.out[i];
|
||||
for (var j=0; j<10; j++) {
|
||||
dense.weights[i][j] <== dense_weights[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<10; i++) {
|
||||
dense.bias[i] <== dense_bias[i];
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ include "../../circuits/ArgMax.circom";
|
||||
include "../../circuits/Poly.circom";
|
||||
include "../../circuits/AveragePooling2D.circom";
|
||||
include "../../circuits/BatchNormalization2D.circom";
|
||||
include "../../circuits/Flatten2D.circom";
|
||||
|
||||
template mnist_latest() {
|
||||
signal input in[28][28][1];
|
||||
@@ -29,6 +30,7 @@ template mnist_latest() {
|
||||
component bn_2 = BatchNormalization2D(11,11,8);
|
||||
component poly_2[11][11][8];
|
||||
component avg2d_2 = AveragePooling2D(11,11,8,2,2,25);
|
||||
component flatten = Flatten2D(5,5,8);
|
||||
component dense = Dense(200,10);
|
||||
component argmax = ArgMax(10);
|
||||
|
||||
@@ -107,20 +109,21 @@ template mnist_latest() {
|
||||
}
|
||||
}
|
||||
|
||||
var idx = 0;
|
||||
|
||||
for (var i=0; i<5; i++) {
|
||||
for (var j=0; j<5; j++) {
|
||||
for (var k=0; k<8; k++) {
|
||||
dense.in[idx] <== avg2d_2.out[i][j][k];
|
||||
for (var m=0; m<10; m++) {
|
||||
dense.weights[idx][m] <== dense_weights[idx][m];
|
||||
}
|
||||
idx++;
|
||||
flatten.in[i][j][k] <== avg2d_2.out[i][j][k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<200; i++) {
|
||||
dense.in[i] <== flatten.out[i];
|
||||
for (var j=0; j<10; j++) {
|
||||
dense.weights[i][j] <== dense_weights[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<10; i++) {
|
||||
dense.bias[i] <== dense_bias[i];
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ include "../../circuits/Conv2D.circom";
|
||||
include "../../circuits/Dense.circom";
|
||||
include "../../circuits/ArgMax.circom";
|
||||
include "../../circuits/Poly.circom";
|
||||
include "../../circuits/Flatten2D.circom";
|
||||
|
||||
template mnist_poly() {
|
||||
signal input in[28][28][1];
|
||||
@@ -14,6 +15,7 @@ template mnist_poly() {
|
||||
signal output out;
|
||||
|
||||
component conv2d = Conv2D(28,28,1,1,3,1);
|
||||
component flatten = Flatten2D(26,26,1);
|
||||
component poly[26*26];
|
||||
component dense = Dense(676,10);
|
||||
component argmax = ArgMax(10);
|
||||
@@ -32,17 +34,18 @@ template mnist_poly() {
|
||||
|
||||
conv2d.bias[0] <== conv2d_bias[0];
|
||||
|
||||
var idx = 0;
|
||||
|
||||
for (var i=0; i<26; i++) {
|
||||
for (var j=0; j<26; j++) {
|
||||
poly[idx] = Poly(10**18);
|
||||
poly[idx].in <== conv2d.out[i][j][0];
|
||||
dense.in[idx] <== poly[idx].out;
|
||||
for (var k=0; k<10; k++) {
|
||||
dense.weights[idx][k] <== dense_weights[idx][k];
|
||||
}
|
||||
idx++;
|
||||
flatten.in[i][j][0] <== conv2d.out[i][j][0];
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<26*26; i++) {
|
||||
poly[i] = Poly(10**18);
|
||||
poly[i].in <== flatten.out[i];
|
||||
dense.in[i] <== poly[i].out;
|
||||
for (var j=0; j<10; j++) {
|
||||
dense.weights[i][j] <== dense_weights[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ include "../../circuits/Conv2D.circom";
|
||||
include "../../circuits/Dense.circom";
|
||||
include "../../circuits/ArgMax.circom";
|
||||
include "../../circuits/ReLU.circom";
|
||||
include "../../circuits/Flatten2D.circom";
|
||||
|
||||
template mnist() {
|
||||
signal input in[28][28][1];
|
||||
@@ -14,6 +15,7 @@ template mnist() {
|
||||
signal output out;
|
||||
|
||||
component conv2d = Conv2D(28,28,1,1,3,1);
|
||||
component flatten = Flatten2D(26,26,1);
|
||||
component relu[26*26];
|
||||
component dense = Dense(676,10);
|
||||
component argmax = ArgMax(10);
|
||||
@@ -32,17 +34,18 @@ template mnist() {
|
||||
|
||||
conv2d.bias[0] <== conv2d_bias[0];
|
||||
|
||||
var idx = 0;
|
||||
|
||||
for (var i=0; i<26; i++) {
|
||||
for (var j=0; j<26; j++) {
|
||||
relu[idx] = ReLU();
|
||||
relu[idx].in <== conv2d.out[i][j][0];
|
||||
dense.in[idx] <== relu[idx].out;
|
||||
for (var k=0; k<10; k++) {
|
||||
dense.weights[idx][k] <== dense_weights[idx][k];
|
||||
}
|
||||
idx++;
|
||||
flatten.in[i][j][0] <== conv2d.out[i][j][0];
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<26*26; i++) {
|
||||
relu[i] = ReLU();
|
||||
relu[i].in <== flatten.out[i];
|
||||
dense.in[i] <== relu[i].out;
|
||||
for (var j=0; j<10; j++) {
|
||||
dense.weights[i][j] <== dense_weights[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ describe("crypto circuits test", function () {
|
||||
assert(Fr.eq(Fr.e(witness[2]), Fr.e(keypair.pubKey.rawPubKey[1])));
|
||||
});
|
||||
|
||||
it("ecdh shared key test", async () => {
|
||||
it("ecdh full circuit test", async () => {
|
||||
const circuit = await wasm_tester(path.join(__dirname, "circuits", "ecdh_test.circom"));
|
||||
|
||||
const keypair = new Keypair();
|
||||
@@ -142,7 +142,69 @@ describe("crypto circuits test", function () {
|
||||
assert(Fr.eq(Fr.e(witness[1]), Fr.e(plaintext)));
|
||||
|
||||
});
|
||||
|
||||
// TODO: encrypt multiple
|
||||
// TODO: decrypt multiple
|
||||
|
||||
it("encrypt multiple in circom, decrypt in js", async () => {
|
||||
const circuit = await wasm_tester(path.join(__dirname, "circuits", "encryptMultiple_test.circom"));
|
||||
|
||||
const keypair = new Keypair();
|
||||
const keypair2 = new Keypair();
|
||||
|
||||
const ecdhSharedKey = Keypair.genEcdhSharedKey(
|
||||
keypair.privKey,
|
||||
keypair2.pubKey,
|
||||
);
|
||||
|
||||
const plaintext = [...Array(1000).keys()];
|
||||
|
||||
const INPUT = {
|
||||
'plaintext': plaintext.map(String),
|
||||
'shared_key': ecdhSharedKey.toString(),
|
||||
}
|
||||
|
||||
const witness = await circuit.calculateWitness(INPUT, true);
|
||||
|
||||
const ciphertext = {
|
||||
iv: witness[1],
|
||||
data: witness.slice(2,1002),
|
||||
}
|
||||
|
||||
decryptedText = decrypt(ciphertext, ecdhSharedKey);
|
||||
|
||||
for (let i=0; i<1000; i++) {
|
||||
assert(Fr.eq(Fr.e(decryptedText[i]), Fr.e(plaintext[i])));
|
||||
}
|
||||
|
||||
});
|
||||
|
||||
it("encrypt multiple in js, decrypt in circom", async () => {
|
||||
const circuit = await wasm_tester(path.join(__dirname, "circuits", "decryptMultiple_test.circom"));
|
||||
|
||||
const keypair = new Keypair();
|
||||
const keypair2 = new Keypair();
|
||||
|
||||
const ecdhSharedKey = Keypair.genEcdhSharedKey(
|
||||
keypair.privKey,
|
||||
keypair2.pubKey,
|
||||
);
|
||||
|
||||
const plaintext = ([...Array(1000).keys()]).map(BigInt);
|
||||
|
||||
const ciphertext = encrypt(plaintext, ecdhSharedKey);
|
||||
|
||||
const INPUT = {
|
||||
'message': [ciphertext.iv.toString(), ...ciphertext.data.map(String)],
|
||||
'shared_key': ecdhSharedKey.toString(),
|
||||
}
|
||||
|
||||
const witness = await circuit.calculateWitness(INPUT, true);
|
||||
|
||||
for (let i=0; i<1000; i++) {
|
||||
assert(Fr.eq(Fr.e(witness[i+1]), Fr.e(plaintext[i])));
|
||||
}
|
||||
|
||||
});
|
||||
|
||||
// TODO: encrypt a model
|
||||
it("encrypt entire model in circom, decrypt in js", async () => {
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user