wip commit to figure out lifetimes

themighty1
2022-08-17 17:47:02 +03:00
parent 63d85ee278
commit 0abb37b06b
2 changed files with 156 additions and 104 deletions

View File

@@ -1,5 +1,5 @@
use num::{BigUint, FromPrimitive, ToPrimitive, Zero};
use prover::ProverError;
use prover::{ProverCore, ProverError};
use verifier::VerifierError;
pub mod onetimesetup;
@@ -11,8 +11,9 @@ pub mod verifier;
const BN254_PRIME: &str =
"21888242871839275222246405745257275088548364400416034343698204186575808495617";
// ProverCore must be implemented by the nodejs and wasm backends
pub trait ProverCore {
// ProverVirtual describes the virtual methods which must be implemented by
// the nodejs and wasm backends
pub trait ProverVirtual {
fn set_proving_key(&mut self, key: Vec<u8>) -> Result<(), ProverError>;
fn poseidon(&mut self, inputs: Vec<BigUint>) -> BigUint;
@@ -20,6 +21,32 @@ pub trait ProverCore {
fn prove(&mut self, input: String) -> Result<Vec<u8>, ProverError>;
}
// provides default implementations for methods which must pass through to the
// ProverCore. Only get_core() must be implemented.
pub trait ProverPassthrough {
fn setup(&mut self, plaintext: Vec<u8>) -> Vec<BigUint> {
self.get_core().setup(plaintext)
}
fn compute_label_sum(
&mut self,
ciphertexts: &Vec<[Vec<u8>; 2]>,
labels: &Vec<u128>,
) -> Vec<BigUint> {
self.get_core().compute_label_sum(ciphertexts, labels)
}
fn create_zk_proof(
&mut self,
zero_sum: Vec<BigUint>,
mut deltas: Vec<BigUint>,
) -> Result<Vec<Vec<u8>>, ProverError> {
self.get_core().create_zk_proof(zero_sum, deltas)
}
fn get_core(&mut self) -> &ProverCore<Box<Self>>;
}
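
For orientation, a minimal self-contained sketch of the split these two traits aim for: a backend implements only the virtual methods plus get_core() and inherits everything else from the passthrough defaults. The types below (Core, WasmBackend) and the &mut return of get_core() are simplified stand-ins, not the real ProverCore API.

use num::BigUint;

// simplified stand-in for ProverCore: owns the protocol state
struct Core {
    field_prime: BigUint,
}

impl Core {
    fn setup(&mut self, plaintext: Vec<u8>) -> usize {
        // placeholder for the real chunking + hashing
        plaintext.len()
    }
}

// stand-in for ProverVirtual: backend-specific methods
trait Virtual {
    fn poseidon(&mut self, inputs: Vec<BigUint>) -> BigUint;
}

// stand-in for ProverPassthrough: default methods forward to the core
trait Passthrough {
    fn get_core(&mut self) -> &mut Core;

    fn setup(&mut self, plaintext: Vec<u8>) -> usize {
        self.get_core().setup(plaintext)
    }
}

// hypothetical wasm backend: implements only the virtual methods + get_core()
struct WasmBackend {
    core: Core,
}

impl Virtual for WasmBackend {
    fn poseidon(&mut self, inputs: Vec<BigUint>) -> BigUint {
        inputs.into_iter().sum() // placeholder for the real Poseidon call
    }
}

impl Passthrough for WasmBackend {
    fn get_core(&mut self) -> &mut Core {
        &mut self.core
    }
}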
pub trait VerifierCore {
fn get_proving_key(&mut self) -> Result<Vec<u8>, VerifierError>;
@@ -42,7 +69,7 @@ mod tests {
use super::*;
use num::{BigUint, FromPrimitive};
use prover::LsumProver;
use prover::Prover;
use rand::{thread_rng, Rng, RngCore};
use verifier::LsumVerifier;
@@ -99,7 +126,7 @@ mod tests {
// produce ciphertexts which are sent to Prover for decryption
verifier.setup(&all_binary_labels);
let mut prover = LsumProver::new(prime);
let mut prover = Prover::new(prime);
prover.set_proving_key(proving_key);
let plaintext_hash = prover.setup(plaintext.to_vec());

View File

@@ -16,7 +16,7 @@ pub enum ProverError {
SnarkjsError,
}
use crate::ProverCore;
use crate::{ProverPassthrough, ProverVirtual};
use super::BN254_PRIME;
@@ -30,30 +30,28 @@ fn check_output(output: Result<Output, std::io::Error>) -> Result<(), ProverErro
Ok(())
}
// implementation of the Prover in the "label_sum" protocol (aka the User).
pub struct LsumProver {
// bytes of the plaintext which was obtained from the garbled circuit
plaintext: Option<Vec<u8>>,
// the prime of the field in which Poseidon hash will be computed.
field_prime: BigUint,
// how many bits to pack into one field element
useful_bits: Option<usize>,
// the size of one chunk == useful_bits * Poseidon_width - 128 (salt size)
chunk_size: Option<usize>,
// We will compute a separate Poseidon hash on each chunk of the plaintext.
// Each chunk contains 16 field elements.
chunks: Option<Vec<[BigUint; 16]>>,
// Poseidon hashes of each chunk
hashes_of_chunks: Option<Vec<BigUint>>,
// each chunk's last 128 bits are used for the salt. This is important for
// security b/c w/o the salt, hashes of plaintext with low entropy could be
// brute-forced.
salts: Option<Vec<BigUint>>,
// hash of all our arithmetic labels
label_sum_hashes: Option<Vec<BigUint>>,
pub struct Prover<'a> {
core: ProverCore<&'a Prover>,
}
impl ProverCore for LsumProver {
impl Prover {
pub fn new(field_prime: BigUint) -> Self {
let core = ProverCore::new(field_prime, Self);
Self { core }
}
pub fn setup(&mut self, plaintext: Vec<u8>) -> Vec<BigUint> {
self.core.setup(plaintext)
}
}
impl ProverPassthrough for Prover {
fn get_core(&mut self) -> &ProverCore<Box<Prover>> {
&self.core
}
}
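
Since this commit is about the caller lifetime, one possible way around the self-referential Prover/ProverCore borrow, sketched with hypothetical simplified types: don't store the caller in the core at all and pass the backend in only for the calls that need a virtual method.

use num::BigUint;

// hypothetical trait for the backend-provided hash
trait PoseidonBackend {
    fn poseidon(&mut self, inputs: Vec<BigUint>) -> BigUint;
}

struct Core;

impl Core {
    // the backend is borrowed only for this call, so Core never stores a
    // reference to its owner and no 'a lifetime parameter is needed
    fn hash_chunks(
        &mut self,
        chunks: &[Vec<BigUint>],
        backend: &mut impl PoseidonBackend,
    ) -> Vec<BigUint> {
        chunks
            .iter()
            .map(|chunk| backend.poseidon(chunk.clone()))
            .collect()
    }
}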
impl ProverVirtual for Prover {
fn set_proving_key(&mut self, key: Vec<u8>) -> Result<(), ProverError> {
let res = fs::write("circuit_final.zkey.verifier", key);
if res.is_err() {
@@ -108,8 +106,31 @@ impl ProverCore for LsumProver {
}
}
impl LsumProver {
pub fn new(field_prime: BigUint) -> Self {
// core implementation of the Prover in the "label_sum" protocol (aka the User).
pub struct ProverCore<'a, T: 'a> {
// bytes of the plaintext which was obtained from the garbled circuit
plaintext: Option<Vec<u8>>,
// the prime of the field in which Poseidon hash will be computed.
field_prime: BigUint,
// how many bits to pack into one field element
useful_bits: Option<usize>,
// the size of one chunk == useful_bits * Poseidon_width - 128 (salt size)
chunk_size: Option<usize>,
// We will compute a separate Poseidon hash on each chunk of the plaintext.
// Each chunk contains 16 field elements.
chunks: Option<Vec<[BigUint; 16]>>,
// Poseidon hashes of each chunk
hashes_of_chunks: Option<Vec<BigUint>>,
// each chunk's last 128 bits are used for the salt. Without the salt, hashes
// of plaintext with low entropy could be brute-forced.
salts: Option<Vec<BigUint>>,
// hash of all our arithmetic labels
label_sum_hashes: Option<Vec<BigUint>>,
caller: &'a T,
}
impl<'a, T> ProverCore<'a, T> {
pub fn new(field_prime: BigUint, caller: &'a T) -> Self {
if field_prime.bits() < 129 {
// last field element must be large enough to contain the 128-bit
// salt. In the future, if we need to support fields < 129 bits,
@@ -125,17 +146,19 @@ impl LsumProver {
hashes_of_chunks: None,
label_sum_hashes: None,
chunk_size: None,
caller,
}
}
// Returns hash digests which are the Prover's commitment to the plaintext
pub fn setup(&mut self, plaintext: Vec<u8>) -> Vec<BigUint> {
self.plaintext = Some(plaintext);
self.plaintext = Some(plaintext.clone());
let useful_bits = calculate_useful_bits(self.field_prime.clone());
self.useful_bits = Some(useful_bits);
let (chunks, salts) = self.plaintext_to_chunks();
let (chunk_size, chunks, salts) = self.plaintext_to_chunks(useful_bits, plaintext);
self.chunks = Some(chunks.clone());
self.salts = Some(salts);
self.chunk_size = Some(chunk_size);
let hashes = self.hash_chunks(chunks);
self.hashes_of_chunks = Some(hashes.clone());
hashes
@@ -150,7 +173,7 @@ impl LsumProver {
labels: &Vec<u128>,
) -> Vec<BigUint> {
// if the binary label's point-and-permute (p&p) bit is 0, decrypt the 1st ciphertext,
// otherwise - the 2nd one.
// otherwise decrypt the 2nd one.
assert!(ciphertexts.len() == labels.len());
assert!(self.plaintext.as_ref().unwrap().len() * 8 == ciphertexts.len());
let mut label_sum_hashes: Vec<BigUint> =
@@ -185,75 +208,6 @@ impl LsumProver {
label_sum_hashes
}
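
A hedged illustration of the decryption choice described above, assuming the point-and-permute bit is the least significant bit of the 128-bit binary label; the helper is hypothetical and not part of this commit.

// returns the ciphertext that the label's p&p bit selects for decryption
fn select_ciphertext(label: u128, pair: &[Vec<u8>; 2]) -> &Vec<u8> {
    if label & 1 == 0 {
        &pair[0]
    } else {
        &pair[1]
    }
}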
// create chunks of plaintext where each chunk consists of 16 field elements.
// The last element's last 128 bits are reserved for the salt of the hash.
// If there is not enough plaintext to fill the whole chunk, we fill the gap
// with zero bits.
fn plaintext_to_chunks(&mut self) -> (Vec<[BigUint; 16]>, Vec<BigUint>) {
let useful_bits = self.useful_bits.unwrap();
// the size of a chunk of plaintext not counting the salt
let chunk_size = useful_bits * 16 - 128;
self.chunk_size = Some(chunk_size);
// plaintext converted into bits
let mut bits = u8vec_to_boolvec(&self.plaintext.as_ref().unwrap());
// chunk count (rounded up)
let chunk_count = (bits.len() + (chunk_size - 1)) / chunk_size;
// extend bits with zeroes to fill the last chunk
bits.extend(vec![false; chunk_count * chunk_size - bits.len()]);
let mut chunks: Vec<[BigUint; 16]> = Vec::with_capacity(chunk_count);
let mut salts: Vec<BigUint> = Vec::with_capacity(chunk_count);
let mut rng = thread_rng();
for chunk_of_bits in bits.chunks(chunk_size) {
// [BigUint::default(); 16] won't work since BigUint doesn't
// implement the Copy trait, so typing out all values
let mut chunk = [
BigUint::default(),
BigUint::default(),
BigUint::default(),
BigUint::default(),
BigUint::default(),
BigUint::default(),
BigUint::default(),
BigUint::default(),
BigUint::default(),
BigUint::default(),
BigUint::default(),
BigUint::default(),
BigUint::default(),
BigUint::default(),
BigUint::default(),
BigUint::default(),
];
// split chunk into 16 field elements
for (i, fe_bits) in chunk_of_bits.chunks(useful_bits).enumerate() {
if i < 15 {
chunk[i] = BigUint::from_bytes_be(&boolvec_to_u8vec(&fe_bits));
} else {
// last field element's last 128 bits are for the salt
let salt = rng.gen::<[u8; 16]>();
salts.push(BigUint::from_bytes_be(&salt));
let mut bits_and_salt = fe_bits.to_vec();
bits_and_salt.extend(u8vec_to_boolvec(&salt).iter());
chunk[15] = BigUint::from_bytes_be(&boolvec_to_u8vec(&bits_and_salt));
};
}
chunks.push(chunk);
}
(chunks, salts)
}
// hashes each chunk with Poseidon and returns digests for each chunk
fn hash_chunks(&mut self, chunks: Vec<[BigUint; 16]>) -> Vec<BigUint> {
let res: Vec<BigUint> = chunks
.iter()
.map(|chunk| self.poseidon(chunk.to_vec()))
.collect();
res
}
pub fn create_zk_proof(
&mut self,
zero_sum: Vec<BigUint>,
@@ -306,6 +260,77 @@ impl LsumProver {
}
Ok(proofs)
}
// create chunks of plaintext where each chunk consists of 16 field elements.
// The last element's last 128 bits are reserved for the salt of the hash.
// If there is not enough plaintext to fill the whole chunk, we fill the gap
// with zero bits.
pub fn plaintext_to_chunks(
&mut self,
useful_bits: usize,
plaintext: Vec<u8>,
) -> (usize, Vec<[BigUint; 16]>, Vec<BigUint>) {
// the size of a chunk of plaintext not counting the salt
let chunk_size = useful_bits * 16 - 128;
// plaintext converted into bits
let mut bits = u8vec_to_boolvec(&plaintext);
// chunk count (rounded up)
let chunk_count = (bits.len() + (chunk_size - 1)) / chunk_size;
// extend bits with zeroes to fill the last chunk
bits.extend(vec![false; chunk_count * chunk_size - bits.len()]);
let mut chunks: Vec<[BigUint; 16]> = Vec::with_capacity(chunk_count);
let mut salts: Vec<BigUint> = Vec::with_capacity(chunk_count);
let mut rng = thread_rng();
for chunk_of_bits in bits.chunks(chunk_size) {
// [BigUint::default(); 16] won't work since BigUint doesn't implement the
// Copy trait, but the Default impl for arrays gives us 16 zero-valued
// BigUints without spelling them all out
let mut chunk: [BigUint; 16] = Default::default();
// split chunk into 16 field elements
for (i, fe_bits) in chunk_of_bits.chunks(useful_bits).enumerate() {
if i < 15 {
chunk[i] = BigUint::from_bytes_be(&boolvec_to_u8vec(&fe_bits));
} else {
// last field element's last 128 bits are for the salt
let salt = rng.gen::<[u8; 16]>();
salts.push(BigUint::from_bytes_be(&salt));
let mut bits_and_salt = fe_bits.to_vec();
bits_and_salt.extend(u8vec_to_boolvec(&salt).iter());
chunk[15] = BigUint::from_bytes_be(&boolvec_to_u8vec(&bits_and_salt));
};
}
chunks.push(chunk);
}
(chunk_size, chunks, salts)
}
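
A quick sanity check of the chunking arithmetic, assuming a hypothetical 253 useful bits per field element (roughly what a 254-bit prime like BN254 would allow):

#[test]
fn chunk_arithmetic_example() {
    // assumed value; calculate_useful_bits is what actually decides this
    let useful_bits = 253usize;
    let chunk_size = useful_bits * 16 - 128; // 3920 plaintext bits (490 bytes) per chunk
    let plaintext_bits = 1000usize * 8; // e.g. a 1000-byte plaintext
    // chunk count (rounded up), same formula as in plaintext_to_chunks
    let chunk_count = (plaintext_bits + (chunk_size - 1)) / chunk_size;
    assert_eq!(chunk_size, 3920);
    assert_eq!(chunk_count, 3);
    // the last chunk gets 3 * 3920 - 8000 = 3760 zero bits of padding
}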
// hashes each chunk with Poseidon and returns digests for each chunk
fn hash_chunks(&mut self, chunks: Vec<[BigUint; 16]>) -> Vec<BigUint> {
let res: Vec<BigUint> = chunks
.iter()
.map(|chunk| self.poseidon(chunk.to_vec()))
.collect();
res
}
}
/// Calculates how many bits of plaintext we will pack into one field element.
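
The body of calculate_useful_bits is outside this hunk; a plausible sketch, assuming we pack field_prime.bits() - 1 bits per element, since any value of that width is strictly smaller than the prime and always fits:

use num::BigUint;

fn calculate_useful_bits(field_prime: BigUint) -> usize {
    // one bit less than the prime's bit length always fits in the field
    (field_prime.bits() - 1) as usize
}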
@@ -380,7 +405,7 @@ mod tests {
plaintext[(15 * 17 + 1) + i * 17 + 16] = (i + 16) as u8;
}
let mut prover = LsumProver::new(prime);
let mut prover = ProverCore::new(prime, None);
prover.setup(plaintext);
// Check chunk1 correctness
@@ -413,7 +438,7 @@ mod tests {
let prime = BigUint::from_bytes_be(&prime);
// plaintext will spawn 2 chunks
let mut plaintext = vec![0u8; 17 * 15 + 1 + 17 * 5];
let mut prover = LsumProver::new(prime);
let mut prover = ProverCore::new(prime, None);
//LsumProver::hash_chunks(&mut prover);
}