mirror of
https://github.com/tlsnotary/label_decoding.git
synced 2026-01-10 04:27:56 -05:00
wip commit to figure out lifetimes
This commit is contained in:
37
src/lib.rs
37
src/lib.rs
@@ -1,5 +1,5 @@
|
||||
use num::{BigUint, FromPrimitive, ToPrimitive, Zero};
|
||||
use prover::ProverError;
|
||||
use prover::{ProverCore, ProverError};
|
||||
use verifier::VerifierError;
|
||||
|
||||
pub mod onetimesetup;
|
||||
@@ -11,8 +11,9 @@ pub mod verifier;
|
||||
const BN254_PRIME: &str =
|
||||
"21888242871839275222246405745257275088548364400416034343698204186575808495617";
|
||||
|
||||
// ProverCore must be implemented by the nodejs and wasm backends
|
||||
pub trait ProverCore {
|
||||
// ProverVirtual describes virtual methods which must be implemented by the
|
||||
// nodejs and wasm implementors
|
||||
pub trait ProverVirtual {
|
||||
fn set_proving_key(&mut self, key: Vec<u8>) -> Result<(), ProverError>;
|
||||
|
||||
fn poseidon(&mut self, inputs: Vec<BigUint>) -> BigUint;
|
||||
@@ -20,6 +21,32 @@ pub trait ProverCore {
|
||||
fn prove(&mut self, input: String) -> Result<Vec<u8>, ProverError>;
|
||||
}
|
||||
|
||||
// provides default implementations for methods which must passthrough to the
|
||||
// LsumProverCore. Only get_core() must be implemented.
|
||||
pub trait ProverPassthrough {
|
||||
fn setup(&mut self, plaintext: Vec<u8>) -> Vec<BigUint> {
|
||||
self.get_core().setup(plaintext)
|
||||
}
|
||||
|
||||
fn compute_label_sum(
|
||||
&mut self,
|
||||
ciphertexts: &Vec<[Vec<u8>; 2]>,
|
||||
labels: &Vec<u128>,
|
||||
) -> Vec<BigUint> {
|
||||
self.get_core().compute_label_sum(ciphertexts, labels)
|
||||
}
|
||||
|
||||
fn create_zk_proof(
|
||||
&mut self,
|
||||
zero_sum: Vec<BigUint>,
|
||||
mut deltas: Vec<BigUint>,
|
||||
) -> Result<Vec<Vec<u8>>, ProverError> {
|
||||
self.get_core().create_zk_proof(zero_sum, deltas)
|
||||
}
|
||||
|
||||
fn get_core(&mut self) -> &ProverCore<Box<Self>>;
|
||||
}
|
||||
|
||||
pub trait VerifierCore {
|
||||
fn get_proving_key(&mut self) -> Result<Vec<u8>, VerifierError>;
|
||||
|
||||
@@ -42,7 +69,7 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use num::{BigUint, FromPrimitive};
|
||||
use prover::LsumProver;
|
||||
use prover::Prover;
|
||||
use rand::{thread_rng, Rng, RngCore};
|
||||
use verifier::LsumVerifier;
|
||||
|
||||
@@ -99,7 +126,7 @@ mod tests {
|
||||
// produce ciphertexts which are sent to Prover for decryption
|
||||
verifier.setup(&all_binary_labels);
|
||||
|
||||
let mut prover = LsumProver::new(prime);
|
||||
let mut prover = Prover::new(prime);
|
||||
prover.set_proving_key(proving_key);
|
||||
let plaintext_hash = prover.setup(plaintext.to_vec());
|
||||
|
||||
|
||||
223
src/prover.rs
223
src/prover.rs
@@ -16,7 +16,7 @@ pub enum ProverError {
|
||||
SnarkjsError,
|
||||
}
|
||||
|
||||
use crate::ProverCore;
|
||||
use crate::{ProverPassthrough, ProverVirtual};
|
||||
|
||||
use super::BN254_PRIME;
|
||||
|
||||
@@ -30,30 +30,28 @@ fn check_output(output: Result<Output, std::io::Error>) -> Result<(), ProverErro
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// implementation of the Prover in the "label_sum" protocol (aka the User).
|
||||
pub struct LsumProver {
|
||||
// bytes of the plaintext which was obtained from the garbled circuit
|
||||
plaintext: Option<Vec<u8>>,
|
||||
// the prime of the field in which Poseidon hash will be computed.
|
||||
field_prime: BigUint,
|
||||
// how many bits to pack into one field element
|
||||
useful_bits: Option<usize>,
|
||||
// the size of one chunk == useful_bits * Poseidon_width - 128 (salt size)
|
||||
chunk_size: Option<usize>,
|
||||
// We will compute a separate Poseidon hash on each chunk of the plaintext.
|
||||
// Each chunk contains 16 field elements.
|
||||
chunks: Option<Vec<[BigUint; 16]>>,
|
||||
// Poseidon hashes of each chunk
|
||||
hashes_of_chunks: Option<Vec<BigUint>>,
|
||||
// each chunk's last 128 bits are used for the salt. This is important for
|
||||
// security b/c w/o the salt, hashes of plaintext with low entropy could be
|
||||
// brute-forced.
|
||||
salts: Option<Vec<BigUint>>,
|
||||
// hash of all our arithmetic labels
|
||||
label_sum_hashes: Option<Vec<BigUint>>,
|
||||
pub struct Prover<'a> {
|
||||
core: ProverCore<&'a Prover>,
|
||||
}
|
||||
|
||||
impl ProverCore for LsumProver {
|
||||
impl Prover {
|
||||
pub fn new(field_prime: BigUint) -> Self {
|
||||
let core = ProverCore::new(field_prime, Self);
|
||||
Self { core }
|
||||
}
|
||||
|
||||
pub fn setup(&mut self, plaintext: Vec<u8>) -> Vec<BigUint> {
|
||||
self.core.setup(plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProverPassthrough for Prover {
|
||||
fn get_core(&mut self) -> &ProverCore<Box<Prover>> {
|
||||
&self.core
|
||||
}
|
||||
}
|
||||
|
||||
impl ProverVirtual for Prover {
|
||||
fn set_proving_key(&mut self, key: Vec<u8>) -> Result<(), ProverError> {
|
||||
let res = fs::write("circuit_final.zkey.verifier", key);
|
||||
if res.is_err() {
|
||||
@@ -108,8 +106,31 @@ impl ProverCore for LsumProver {
|
||||
}
|
||||
}
|
||||
|
||||
impl LsumProver {
|
||||
pub fn new(field_prime: BigUint) -> Self {
|
||||
// implementation of the Prover in the "label_sum" protocol (aka the User).
|
||||
pub struct ProverCore<'a, T: 'a> {
|
||||
// bytes of the plaintext which was obtained from the garbled circuit
|
||||
plaintext: Option<Vec<u8>>,
|
||||
// the prime of the field in which Poseidon hash will be computed.
|
||||
field_prime: BigUint,
|
||||
// how many bits to pack into one field element
|
||||
useful_bits: Option<usize>,
|
||||
// the size of one chunk == useful_bits * Poseidon_width - 128 (salt size)
|
||||
chunk_size: Option<usize>,
|
||||
// We will compute a separate Poseidon hash on each chunk of the plaintext.
|
||||
// Each chunk contains 16 field elements.
|
||||
chunks: Option<Vec<[BigUint; 16]>>,
|
||||
// Poseidon hashes of each chunk
|
||||
hashes_of_chunks: Option<Vec<BigUint>>,
|
||||
// each chunk's last 128 bits are used for the salt. w/o the salt, hashes
|
||||
// of plaintext with low entropy could be brute-forced.
|
||||
salts: Option<Vec<BigUint>>,
|
||||
// hash of all our arithmetic labels
|
||||
label_sum_hashes: Option<Vec<BigUint>>,
|
||||
caller: &'a T,
|
||||
}
|
||||
|
||||
impl<T> ProverCore<T> {
|
||||
pub fn new(field_prime: BigUint, caller: &T) -> Self {
|
||||
if field_prime.bits() < 129 {
|
||||
// last field element must be large enough to contain the 128-bit
|
||||
// salt. In the future, if we need to support fields < 129 bits,
|
||||
@@ -125,17 +146,19 @@ impl LsumProver {
|
||||
hashes_of_chunks: None,
|
||||
label_sum_hashes: None,
|
||||
chunk_size: None,
|
||||
caller,
|
||||
}
|
||||
}
|
||||
|
||||
// Return hash digests which is Prover's commitment to the plaintext
|
||||
pub fn setup(&mut self, plaintext: Vec<u8>) -> Vec<BigUint> {
|
||||
self.plaintext = Some(plaintext);
|
||||
self.plaintext = Some(plaintext.clone());
|
||||
let useful_bits = calculate_useful_bits(self.field_prime.clone());
|
||||
self.useful_bits = Some(useful_bits);
|
||||
let (chunks, salts) = self.plaintext_to_chunks();
|
||||
let (chunk_size, chunks, salts) = self.plaintext_to_chunks(useful_bits, plaintext);
|
||||
self.chunks = Some(chunks.clone());
|
||||
self.salts = Some(salts);
|
||||
self.chunk_size = Some(chunk_size);
|
||||
let hashes = self.hash_chunks(chunks);
|
||||
self.hashes_of_chunks = Some(hashes.clone());
|
||||
hashes
|
||||
@@ -150,7 +173,7 @@ impl LsumProver {
|
||||
labels: &Vec<u128>,
|
||||
) -> Vec<BigUint> {
|
||||
// if binary label's p&p bit is 0, decrypt the 1st ciphertext,
|
||||
// otherwise - the 2nd one.
|
||||
// otherwise decrypt the 2nd one.
|
||||
assert!(ciphertexts.len() == labels.len());
|
||||
assert!(self.plaintext.as_ref().unwrap().len() * 8 == ciphertexts.len());
|
||||
let mut label_sum_hashes: Vec<BigUint> =
|
||||
@@ -185,75 +208,6 @@ impl LsumProver {
|
||||
label_sum_hashes
|
||||
}
|
||||
|
||||
// create chunks of plaintext where each chunk consists of 16 field elements.
|
||||
// The last element's last 128 bits are reserved for the salt of the hash.
|
||||
// If there is not enough plaintext to fill the whole chunk, we fill the gap
|
||||
// with zero bits.
|
||||
fn plaintext_to_chunks(&mut self) -> (Vec<[BigUint; 16]>, Vec<BigUint>) {
|
||||
let useful_bits = self.useful_bits.unwrap();
|
||||
// the size of a chunk of plaintext not counting the salt
|
||||
let chunk_size = useful_bits * 16 - 128;
|
||||
self.chunk_size = Some(chunk_size);
|
||||
|
||||
// plaintext converted into bits
|
||||
let mut bits = u8vec_to_boolvec(&self.plaintext.as_ref().unwrap());
|
||||
// chunk count (rounded up)
|
||||
let chunk_count = (bits.len() + (chunk_size - 1)) / chunk_size;
|
||||
// extend bits with zeroes to fill the last chunk
|
||||
bits.extend(vec![false; chunk_count * chunk_size - bits.len()]);
|
||||
let mut chunks: Vec<[BigUint; 16]> = Vec::with_capacity(chunk_count);
|
||||
let mut salts: Vec<BigUint> = Vec::with_capacity(chunk_count);
|
||||
|
||||
let mut rng = thread_rng();
|
||||
for chunk_of_bits in bits.chunks(chunk_size) {
|
||||
// [BigUint::default(); 16] won't work since BigUint doesn't
|
||||
// implement the Copy trait, so typing out all values
|
||||
let mut chunk = [
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
];
|
||||
|
||||
// split chunk into 16 field elements
|
||||
for (i, fe_bits) in chunk_of_bits.chunks(useful_bits).enumerate() {
|
||||
if i < 15 {
|
||||
chunk[i] = BigUint::from_bytes_be(&boolvec_to_u8vec(&fe_bits));
|
||||
} else {
|
||||
// last field element's last 128 bits are for the salt
|
||||
let salt = rng.gen::<[u8; 16]>();
|
||||
salts.push(BigUint::from_bytes_be(&salt));
|
||||
let mut bits_and_salt = fe_bits.to_vec();
|
||||
bits_and_salt.extend(u8vec_to_boolvec(&salt).iter());
|
||||
chunk[15] = BigUint::from_bytes_be(&boolvec_to_u8vec(&bits_and_salt));
|
||||
};
|
||||
}
|
||||
chunks.push(chunk);
|
||||
}
|
||||
(chunks, salts)
|
||||
}
|
||||
|
||||
// hashes each chunk with Poseidon and returns digests for each chunk
|
||||
fn hash_chunks(&mut self, chunks: Vec<[BigUint; 16]>) -> Vec<BigUint> {
|
||||
let res: Vec<BigUint> = chunks
|
||||
.iter()
|
||||
.map(|chunk| self.poseidon(chunk.to_vec()))
|
||||
.collect();
|
||||
res
|
||||
}
|
||||
|
||||
pub fn create_zk_proof(
|
||||
&mut self,
|
||||
zero_sum: Vec<BigUint>,
|
||||
@@ -306,6 +260,77 @@ impl LsumProver {
|
||||
}
|
||||
Ok(proofs)
|
||||
}
|
||||
|
||||
// create chunks of plaintext where each chunk consists of 16 field elements.
|
||||
// The last element's last 128 bits are reserved for the salt of the hash.
|
||||
// If there is not enough plaintext to fill the whole chunk, we fill the gap
|
||||
// with zero bits.
|
||||
pub fn plaintext_to_chunks(
|
||||
&mut self,
|
||||
useful_bits: usize,
|
||||
plaintext: Vec<u8>,
|
||||
) -> (usize, Vec<[BigUint; 16]>, Vec<BigUint>) {
|
||||
// the size of a chunk of plaintext not counting the salt
|
||||
let chunk_size = useful_bits * 16 - 128;
|
||||
|
||||
// plaintext converted into bits
|
||||
let mut bits = u8vec_to_boolvec(&plaintext);
|
||||
// chunk count (rounded up)
|
||||
let chunk_count = (bits.len() + (chunk_size - 1)) / chunk_size;
|
||||
// extend bits with zeroes to fill the last chunk
|
||||
bits.extend(vec![false; chunk_count * chunk_size - bits.len()]);
|
||||
let mut chunks: Vec<[BigUint; 16]> = Vec::with_capacity(chunk_count);
|
||||
let mut salts: Vec<BigUint> = Vec::with_capacity(chunk_count);
|
||||
|
||||
let mut rng = thread_rng();
|
||||
for chunk_of_bits in bits.chunks(chunk_size) {
|
||||
// [BigUint::default(); 16] won't work since BigUint doesn't
|
||||
// implement the Copy trait, so typing out all values
|
||||
let mut chunk = [
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
BigUint::default(),
|
||||
];
|
||||
|
||||
// split chunk into 16 field elements
|
||||
for (i, fe_bits) in chunk_of_bits.chunks(useful_bits).enumerate() {
|
||||
if i < 15 {
|
||||
chunk[i] = BigUint::from_bytes_be(&boolvec_to_u8vec(&fe_bits));
|
||||
} else {
|
||||
// last field element's last 128 bits are for the salt
|
||||
let salt = rng.gen::<[u8; 16]>();
|
||||
salts.push(BigUint::from_bytes_be(&salt));
|
||||
let mut bits_and_salt = fe_bits.to_vec();
|
||||
bits_and_salt.extend(u8vec_to_boolvec(&salt).iter());
|
||||
chunk[15] = BigUint::from_bytes_be(&boolvec_to_u8vec(&bits_and_salt));
|
||||
};
|
||||
}
|
||||
chunks.push(chunk);
|
||||
}
|
||||
(chunk_size, chunks, salts)
|
||||
}
|
||||
|
||||
// hashes each chunk with Poseidon and returns digests for each chunk
|
||||
fn hash_chunks(&mut self, chunks: Vec<[BigUint; 16]>) -> Vec<BigUint> {
|
||||
let res: Vec<BigUint> = chunks
|
||||
.iter()
|
||||
.map(|chunk| self.poseidon(chunk.to_vec()))
|
||||
.collect();
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculates how many bits of plaintext we will pack into one field element.
|
||||
@@ -380,7 +405,7 @@ mod tests {
|
||||
plaintext[(15 * 17 + 1) + i * 17 + 16] = (i + 16) as u8;
|
||||
}
|
||||
|
||||
let mut prover = LsumProver::new(prime);
|
||||
let mut prover = ProverCore::new(prime, None);
|
||||
prover.setup(plaintext);
|
||||
|
||||
// Check chunk1 correctness
|
||||
@@ -413,7 +438,7 @@ mod tests {
|
||||
let prime = BigUint::from_bytes_be(&prime);
|
||||
// plaintext will spawn 2 chunks
|
||||
let mut plaintext = vec![0u8; 17 * 15 + 1 + 17 * 5];
|
||||
let mut prover = LsumProver::new(prime);
|
||||
let mut prover = ProverCore::new(prime, None);
|
||||
//LsumProver::hash_chunks(&mut prover);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user