From 9b903c4f868d0a8177a0f85faa24b39abc0e59db Mon Sep 17 00:00:00 2001 From: themighty1 Date: Mon, 8 Aug 2022 14:01:44 +0300 Subject: [PATCH] bin2ar works --- Cargo.toml | 2 + README | 8 +++ circuit.circom | 5 +- src/lib.rs | 98 +++++++++++++++++--------------- src/prover.rs | 98 ++++++++++++++++++++++++-------- src/verifier.rs | 147 +++++++++++++++++++++++++++++++++++++++++++++--- 6 files changed, 284 insertions(+), 74 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e2e6dff..b044ead 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,3 +12,5 @@ name = "label_sum" num = { version = "0.4"} rand = "0.8.5" json = "0.12.4" +aes = { version = "0.7.5", features = [] } +cipher = "0.3" diff --git a/README b/README index 78ba3c9..7ef33fb 100644 --- a/README +++ b/README @@ -21,3 +21,11 @@ snarkjs groth16 fullprove input.json circuit_js/circuit.wasm circuit_final.zkey snarkjs groth16 verify verification_key.json public.json proof.json + + + +We can generate circuit.wasm and circuit.r1cs deterministically with circom 2.0.5+ +circom circuit.circom --r1cs --wasm +and then ship the User with .wasm and the Notary with .r1cs + +write a onetimesetup.js script and call it with node \ No newline at end of file diff --git a/circuit.circom b/circuit.circom index 8cb4d63..da39ae0 100644 --- a/circuit.circom +++ b/circuit.circom @@ -21,7 +21,9 @@ template Main() { for (var i = 0; i [u8; 16] { + let mut bytes = [0u8; 16]; + // leave high zero bytes untouched + bytes[16 - b.len()..].copy_from_slice(b); + bytes +} + #[cfg(test)] mod tests { - use crate::prover::{boolvec_to_u8vec, u8vec_to_boolvec}; + use crate::{ + onetimesetup::OneTimeSetup, + prover::{boolvec_to_u8vec, u8vec_to_boolvec}, + }; use super::*; use num::{BigUint, FromPrimitive}; @@ -32,57 +45,54 @@ mod tests { #[test] fn e2e_test() { + let prime = String::from(BN254_PRIME).parse::().unwrap(); let mut rng = thread_rng(); - // random 490 byte plaintext - // we have 16 FE * 253 bits each - 128 bits (salt) == 490 bytes + // OneTimeSetup is a no-op if the setup has been run before + let mut ots = OneTimeSetup::new(); + ots.setup().unwrap(); + + // random 490 byte plaintext. This is the size of one chunk. + // Our Poseidon is 16-arity * 253 bits each - 128 bits (salt) == 490 bytes let mut plaintext = [0u8; 512]; rng.fill(&mut plaintext); let plaintext = &plaintext[0..490]; - // bn254 prime 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 - // in decimal 21888242871839275222246405745257275088548364400416034343698204186575808495617 - // TODO: this will be done internally by verifier. But for now, until b2a converter is - // implemented, we do it here. - - // generate as many 128-bit arithm label pairs as there are plaintext bits. - // The 128-bit size is for convenience to be able to encrypt the label with 1 - // call to AES. - // To keep the handling simple, we want to avoid a negative delta, that's why - // W_0 and delta must be a 127-bit value and W_1 will be set to W_0 + delta - let bitsize = plaintext.len() * 8; - let mut zero_sum = BigUint::from_u8(0).unwrap(); - let mut deltas: Vec = Vec::with_capacity(bitsize); - let arithm_labels: Vec<[BigUint; 2]> = (0..bitsize) - .map(|_| { - let zero_label = random_bigint(127); - let delta = random_bigint(127); - let one_label = zero_label.clone() + delta.clone(); - zero_sum += zero_label.clone(); - deltas.push(delta); - [zero_label, one_label] - }) - .collect(); - - let prime = String::from(BN254_PRIME).parse::().unwrap(); - let mut prover = LsumProver::new(plaintext.to_vec(), prime); - let plaintext_hash = prover.setup(); - // Commitment to the plaintext is sent to the Notary - let mut verifier = LsumVerifier::new(); - verifier.receive_pt_hashes(plaintext_hash); - // Verifier sends back encrypted arithm. labels. We skip this step - // and simulate Prover's deriving his arithm labels: - let prover_labels = choose(&arithm_labels, &u8vec_to_boolvec(&plaintext)); - let mut label_sum = BigUint::from_u8(0).unwrap(); - for i in 0..prover_labels.len() { - label_sum += prover_labels[i].clone(); + // Normally, the Prover is expected to obtain her binary labels by + // evaluating the garbled circuit. + // To keep this test simple, we don't evaluate the gc, but we generate + // all labels of the Verifier and give the Prover her active labels. + let bit_size = plaintext.len() * 8; + let mut all_binary_labels: Vec<[u128; 2]> = Vec::with_capacity(bit_size); + let mut delta: u128 = rng.gen(); + // set the last bit + delta |= 1; + for _ in 0..bit_size { + let label_zero: u128 = rng.gen(); + all_binary_labels.push([label_zero, label_zero ^ delta]); } - // Prover sends a hash commitment to label_sum - let label_sum_hash = prover.poseidon(vec![label_sum.clone()]); - // Commitment to the label_sum is sent to the Notary - verifier.receive_labelsum_hash(label_sum_hash); + let prover_labels = choose(&all_binary_labels, &u8vec_to_boolvec(&plaintext)); + + let mut verifier = LsumVerifier::new(true); + // passing proving key to Prover (if he needs one) + let proving_key = verifier.get_proving_key().unwrap(); + // produce ciphertexts which are sent to Prover for decryption + verifier.setup(&all_binary_labels); + + let mut prover = LsumProver::new(prime); + prover.set_proving_key(proving_key); + let plaintext_hash = prover.setup(plaintext.to_vec()); + + // Commitment to the plaintext is sent to the Verifier + let cipheretexts = verifier.receive_pt_hashes(plaintext_hash); + // Verifier sends back encrypted arithm. labels. + + let label_sum_hash = prover.compute_label_sum(&cipheretexts, &prover_labels); + // Hash commitment to the label_sum is sent to the Notary + + let (deltas, zero_sum) = verifier.receive_labelsum_hash(label_sum_hash); // Notary sends zero_sum and all deltas // Prover constructs input to snarkjs - prover.create_zk_proof(zero_sum, deltas, label_sum); + prover.create_zk_proof(zero_sum, deltas); } } diff --git a/src/prover.rs b/src/prover.rs index 0236458..6aa8b1d 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -1,17 +1,24 @@ +use aes::{Aes128, BlockDecrypt, NewBlockCipher}; +use cipher::{consts::U16, generic_array::GenericArray, BlockCipher, BlockEncrypt}; use json::{object, stringify, stringify_pretty}; -/// implementing the "label sum" protocol -/// +use num::bigint::ToBigUint; use num::{BigUint, FromPrimitive, ToPrimitive, Zero}; use rand::{thread_rng, Rng}; use std::fs; use std::process::Command; use std::str; +#[derive(Debug)] +pub enum Error { + ProvingKeyNotFound, + FileSystemError, +} + use super::BN254_PRIME; -// implementation of the Prover in the "label sum" protocol (aka the User). +// implementation of the Prover in the "label_sum" protocol (aka the User). pub struct LsumProver { - plaintext: Vec, + plaintext: Option>, // the prime of the field in which Poseidon hash will be computed. field_prime: BigUint, // how many bits to pack into one field element @@ -25,10 +32,12 @@ pub struct LsumProver { // security b/c w/o the salt, hashes of plaintext with low entropy could be // brute-forced. salts: Option>, + // hash of all our arithmetic labels + label_sum_hash: Option, } impl LsumProver { - pub fn new(plaintext: Vec, field_prime: BigUint) -> Self { + pub fn new(field_prime: BigUint) -> Self { if field_prime.bits() < 129 { // last field element must be large enough to contain the 128-bit // salt. In the future, if we need to support fields < 129 bits, @@ -36,18 +45,28 @@ impl LsumProver { panic!("Error: expected a prime >= 129 bits"); } Self { - plaintext, + plaintext: None, field_prime, useful_bits: None, chunks: None, salts: None, hashes_of_chunks: None, + label_sum_hash: None, } } + pub fn set_proving_key(&mut self, key: Vec) -> Result<(), Error> { + let res = fs::write("circuit_final.zkey", key); + if res.is_err() { + return Err(Error::FileSystemError); + } + Ok(()) + } + // Return hash digests which is Prover's commitment to the plaintext - pub fn setup(&mut self) -> Vec { - let useful_bits = compute_useful_bits(self.field_prime.clone()); + pub fn setup(&mut self, plaintext: Vec) -> Vec { + self.plaintext = Some(plaintext); + let useful_bits = calculate_useful_bits(self.field_prime.clone()); self.useful_bits = Some(useful_bits); let (chunks, salts) = self.plaintext_to_chunks(); self.chunks = Some(chunks.clone()); @@ -56,6 +75,41 @@ impl LsumProver { self.hashes_of_chunks.as_ref().unwrap().to_vec() } + // decrypt each encrypted arithm.label based on the p&p bit of our active + // binary label. Return the hash of the sum of all arithm. labels. + pub fn compute_label_sum( + &mut self, + ciphertexts: &Vec<[Vec; 2]>, + labels: &Vec, + ) -> BigUint { + // if binary label's p&p bit is 0, decrypt the 1st ciphertext, + // otherwise - the 2nd one. + assert!(ciphertexts.len() == labels.len()); + let mut label_sum = BigUint::from_u8(0).unwrap(); + let _unused: Vec<()> = ciphertexts + .iter() + .zip(labels) + .map(|(ct_pair, label)| { + let key = Aes128::new_from_slice(&label.to_be_bytes()).unwrap(); + // choose which ciphertext to decrypt based on the point-and-permute bit + let mut ct = [0u8; 16]; + if label & 1 == 0 { + ct.copy_from_slice(&ct_pair[0]); + } else { + ct.copy_from_slice(&ct_pair[1]); + } + let mut ct: GenericArray = GenericArray::from(ct); + key.decrypt_block(&mut ct); + // add the decrypted arithmetic label to the sum + label_sum += BigUint::from_bytes_be(&ct); + }) + .collect(); + println!("{:?} label_sum", label_sum); + let label_sum_hash = self.poseidon(vec![label_sum]); + self.label_sum_hash = Some(label_sum_hash.clone()); + label_sum_hash + } + // create chunks of plaintext where each chunk consists of 16 field elements. // The last element's last 128 bits are reserved for the salt of the hash. // If there is not enough plaintext to fill the whole chunk, we fill the gap @@ -67,7 +121,7 @@ impl LsumProver { //let chunk_size = useful_bits * 16; // plaintext converted into bits - let mut bits = u8vec_to_boolvec(&self.plaintext); + let mut bits = u8vec_to_boolvec(&self.plaintext.as_ref().unwrap()); // chunk count (rounded up) let chunk_count = (bits.len() + (chunk_size - 1)) / chunk_size; // extend bits with zeroes to fill the chunk @@ -97,8 +151,7 @@ impl LsumProver { BigUint::default(), BigUint::default(), ]; - // TODO dont use salt for now, to make debugging easier, later change this to - // for j in 0..15 { and uncomment the lines below //offset and //chunk[15] + for j in 0..15 { // convert bits into field element chunk[j] = @@ -148,6 +201,7 @@ impl LsumProver { .args(["poseidon.mjs", &json]) .output() .unwrap(); + println!("{:?}", output); // drop the trailing new line let output = &output.stdout[0..output.stdout.len() - 1]; let s = String::from_utf8(output.to_vec()).unwrap(); @@ -156,9 +210,9 @@ impl LsumProver { bi } - pub fn create_zk_proof(&mut self, zero_sum: BigUint, deltas: Vec, label_sum: BigUint) { + pub fn create_zk_proof(&mut self, zero_sum: BigUint, deltas: Vec) { // hash label_sum - let label_sum_hash = self.poseidon(vec![label_sum.clone()]); + let label_sum_hash = self.label_sum_hash.as_ref().unwrap().clone(); // write inputs into input.json let pt_str: Vec = self.chunks.as_ref().unwrap()[0] @@ -181,8 +235,8 @@ impl LsumProver { let delta_last_str: Vec = delta_last.iter().map(|v| v.to_string()).collect(); let mut data = object! { - label_sum_hash: label_sum_hash.to_string(), plaintext_hash: self.hashes_of_chunks.as_ref().unwrap()[0].to_string(), + label_sum_hash: label_sum_hash.to_string(), sum_of_zero_labels: zero_sum.to_string(), plaintext: pt_str, delta: delta_str, @@ -193,18 +247,18 @@ impl LsumProver { } } -/// Computes how many bits of plaintext we will pack into one field element. +/// Calculates how many bits of plaintext we will pack into one field element. /// Essentially, this is field_prime bit length minus 1. -fn compute_useful_bits(field_prime: BigUint) -> usize { +fn calculate_useful_bits(field_prime: BigUint) -> usize { (field_prime.bits() - 1) as usize } #[test] fn test_compute_full_bits() { - assert_eq!(compute_useful_bits(BigUint::from_u16(13).unwrap()), 3); - assert_eq!(compute_useful_bits(BigUint::from_u16(255).unwrap()), 7); + assert_eq!(calculate_useful_bits(BigUint::from_u16(13).unwrap()), 3); + assert_eq!(calculate_useful_bits(BigUint::from_u16(255).unwrap()), 7); assert_eq!( - compute_useful_bits(String::from(BN254_PRIME,).parse::().unwrap()), + calculate_useful_bits(String::from(BN254_PRIME,).parse::().unwrap()), 253 ); } @@ -265,8 +319,8 @@ mod tests { plaintext[(15 * 17 + 1) + i * 17 + 16] = (i + 16) as u8; } - let mut prover = LsumProver::new(plaintext, prime); - prover.setup(); + let mut prover = LsumProver::new(prime); + prover.setup(plaintext); // Check chunk1 correctness let chunk1: Vec = prover.chunks.clone().unwrap()[0][0..15] @@ -298,7 +352,7 @@ mod tests { let prime = BigUint::from_bytes_be(&prime); // plaintext will spawn 2 chunks let mut plaintext = vec![0u8; 17 * 15 + 1 + 17 * 5]; - let mut prover = LsumProver::new(plaintext, prime); + let mut prover = LsumProver::new(prime); //LsumProver::hash_chunks(&mut prover); } diff --git a/src/verifier.rs b/src/verifier.rs index d6efaf3..81e60f1 100644 --- a/src/verifier.rs +++ b/src/verifier.rs @@ -1,28 +1,161 @@ +use super::to_16_bytes; +use aes::{Aes128, NewBlockCipher}; +use cipher::{consts::U16, generic_array::GenericArray, BlockCipher, BlockEncrypt}; use num::{BigUint, FromPrimitive, ToPrimitive, Zero}; +use rand::{thread_rng, Rng}; +use std::fs; +use std::path::Path; + +#[derive(Debug)] +pub enum Error { + ProvingKeyNotFound, + FileSystemError, +} // implementation of the Verifier in the "label sum" protocol (aka the Notary). pub struct LsumVerifier { // hashes for each chunk of Prover's plaintext plaintext_hashes: Option>, labelsum_hash: Option, + // if set to true, then we must send the proving key to the Prover + // before this protocol begins. Otherwise, it is assumed that the Prover + // already has the proving key from a previous interaction with us. + proving_key_needed: bool, + deltas: Option>, + zero_sum: Option, + ciphertexts: Option; 2]>>, } impl LsumVerifier { - pub fn new() -> Self { + pub fn new(proving_key_needed: bool) -> Self { Self { plaintext_hashes: None, labelsum_hash: None, + proving_key_needed, + deltas: None, + zero_sum: None, + ciphertexts: None, } } - // receive hashes of plaintext and reveal the arithmetic labels - pub fn receive_pt_hashes(&mut self, hashes: Vec) { - self.plaintext_hashes = Some(hashes); - // TODO at this stage we send 2 ciphertexts (encrypted arithm. labels), - // only 1 of which the User can decrypt + pub fn get_proving_key(&mut self) -> Result, Error> { + if !Path::new("circuit_final.zkey").exists() { + return Err(Error::ProvingKeyNotFound); + } + let res = fs::read("circuit_final.zkey"); + if res.is_err() { + return Err(Error::FileSystemError); + } + Ok(res.unwrap()) } - pub fn receive_labelsum_hash(&mut self, hash: BigUint) { + // Convert binary labels into encrypted arithmetic labels. + // Prepare JSON objects to be converted into proof.json before verification + pub fn setup(&mut self, labels: &Vec<[u128; 2]>) { + // generate as many 128-bit arithm label pairs as there are plaintext bits. + // The 128-bit size is for convenience to be able to encrypt the label with 1 + // call to AES. + // To keep the handling simple, we want to avoid a negative delta, that's why + // W_0 and delta must be 127-bit values and W_1 will be set to W_0 + delta + let bitsize = labels.len(); + let mut zero_sum = BigUint::from_u8(0).unwrap(); + let mut deltas: Vec = Vec::with_capacity(bitsize); + let arithm_labels: Vec<[BigUint; 2]> = (0..bitsize) + .map(|_| { + let zero_label = random_bigint(127); + let delta = random_bigint(127); + let one_label = zero_label.clone() + delta.clone(); + zero_sum += zero_label.clone(); + deltas.push(delta); + [zero_label, one_label] + }) + .collect(); + self.zero_sum = Some(zero_sum); + self.deltas = Some(deltas); + + // encrypt each arithmetic label using a corresponding binary label as a key + // place ciphertexts in an order based on binary label's p&p bit + let ciphertexts: Vec<[Vec; 2]> = labels + .iter() + .zip(arithm_labels) + .map(|(bin_pair, arithm_pair)| { + let zero_key = Aes128::new_from_slice(&bin_pair[0].to_be_bytes()).unwrap(); + let one_key = Aes128::new_from_slice(&bin_pair[1].to_be_bytes()).unwrap(); + let mut label0 = [0u8; 16]; + let mut label1 = [0u8; 16]; + //println!("{:?}", arithm_pair[0].to_bytes_be()); + //println!("{:?}", arithm_pair[1].to_bytes_be()); + let ap0 = arithm_pair[0].to_bytes_be(); + let ap1 = arithm_pair[1].to_bytes_be(); + label0[16 - ap0.len()..].copy_from_slice(&ap0); + label1[16 - ap1.len()..].copy_from_slice(&ap1); + let mut label0: GenericArray = GenericArray::from(label0); + let mut label1: GenericArray = GenericArray::from(label1); + zero_key.encrypt_block(&mut label0); + one_key.encrypt_block(&mut label1); + // ciphertext 0 and ciphertext 1 + let ct0 = label0.to_vec(); + let ct1 = label1.to_vec(); + // get point and permute bit of binary label 0 + if (bin_pair[0] & 1) == 0 { + [ct0, ct1] + } else { + [ct1, ct0] + } + }) + .collect(); + self.ciphertexts = Some(ciphertexts); + } + + // receive hashes of plaintext and reveal the encrypted arithmetic labels + pub fn receive_pt_hashes(&mut self, hashes: Vec) -> Vec<[Vec; 2]> { + self.plaintext_hashes = Some(hashes); + self.ciphertexts.as_ref().unwrap().clone() + } + + // receive the hash commitment to the Prover's sum of labels and reveal all + // deltas and zero_sum. + pub fn receive_labelsum_hash(&mut self, hash: BigUint) -> (Vec, BigUint) { self.labelsum_hash = Some(hash); + ( + self.deltas.as_ref().unwrap().clone(), + self.zero_sum.as_ref().unwrap().clone(), + ) } } + +fn random_bigint(bitsize: usize) -> BigUint { + assert!(bitsize <= 128); + let r: [u8; 16] = thread_rng().gen(); + // take only those bits which we need + BigUint::from_bytes_be(&boolvec_to_u8vec(&u8vec_to_boolvec(&r)[0..bitsize])) +} + +#[inline] +pub fn u8vec_to_boolvec(v: &[u8]) -> Vec { + let mut bv = Vec::with_capacity(v.len() * 8); + for byte in v.iter() { + for i in 0..8 { + bv.push(((byte >> (7 - i)) & 1) != 0); + } + } + bv +} + +// Convert bits into bytes. The bits will be left-padded with zeroes to the +// multiple of 8. +#[inline] +pub fn boolvec_to_u8vec(bv: &[bool]) -> Vec { + let rem = bv.len() % 8; + let first_byte_bitsize = if rem == 0 { 8 } else { rem }; + let offset = if rem == 0 { 0 } else { 1 }; + let mut v = vec![0u8; bv.len() / 8 + offset]; + // implicitely left-pad the first byte with zeroes + for (i, b) in bv[0..first_byte_bitsize].iter().enumerate() { + v[i / 8] |= (*b as u8) << (first_byte_bitsize - 1 - i); + } + for (i, b) in bv[first_byte_bitsize..].iter().enumerate() { + v[1 + i / 8] |= (*b as u8) << (7 - (i % 8)); + } + v +}