chunking works with 1 chunk

This commit is contained in:
themighty1
2022-08-11 16:10:42 +03:00
parent 5dff68b194
commit 75d75ce49f
4 changed files with 192 additions and 134 deletions

View File

@@ -11,7 +11,7 @@ async function main(){
poseidonReference.F.toRprBE(buff, 0, res);
const rv = bufToBn(buff);
// print to stdout. This is how Rust gets the output
// print to stdout. This is how the Rust caller reads the output
console.log(bufToBn(buff).toString());
}

View File

@@ -53,7 +53,7 @@ mod tests {
ots.setup().unwrap();
// random 490 byte plaintext. This is the size of one chunk.
// Our Poseidon is 16-arity * 253 bits each - 128 bits (salt) == 490 bytes
// Our Poseidon is 16-width * 253 bits each - 128 bits (salt) == 490 bytes
let mut plaintext = [0u8; 512];
rng.fill(&mut plaintext);
let plaintext = &plaintext[0..320];
@@ -87,15 +87,15 @@ mod tests {
let cipheretexts = verifier.receive_pt_hashes(plaintext_hash);
// Verifier sends back encrypted arithm. labels.
let label_sum_hash = prover.compute_label_sum(&cipheretexts, &prover_labels);
let label_sum_hashes = prover.compute_label_sum(&cipheretexts, &prover_labels);
// Hash commitment to the label_sum is sent to the Notary
let (deltas, zero_sum) = verifier.receive_labelsum_hash(label_sum_hash);
let (deltas, zero_sums) = verifier.receive_labelsum_hash(label_sum_hashes);
// Notary sends zero_sum and all deltas
// Prover constructs input to snarkjs
let proof = prover.create_zk_proof(zero_sum, deltas).unwrap();
let proofs = prover.create_zk_proof(zero_sums, deltas).unwrap();
// Verifier verifies the proof
assert_eq!(verifier.verify(proof).unwrap(), true);
assert_eq!(verifier.verify(proofs).unwrap(), true);
}
}

View File

@@ -46,7 +46,8 @@ pub struct LsumProver {
// brute-forced.
salts: Option<Vec<BigUint>>,
// hash of all our arithmetic labels
label_sum_hash: Option<BigUint>,
label_sum_hashes: Option<Vec<BigUint>>,
chunk_size: Option<usize>,
}
impl LsumProver {
@@ -64,7 +65,8 @@ impl LsumProver {
chunks: None,
salts: None,
hashes_of_chunks: None,
label_sum_hash: None,
label_sum_hashes: None,
chunk_size: None,
}
}
@@ -89,20 +91,27 @@ impl LsumProver {
}
// decrypt each encrypted arithm.label based on the p&p bit of our active
// binary label. Return the hash of the sum of all arithm. labels.
// binary label. Return the hash of the sum of all arithm. labels. Note
// that we compute a separate label sum for each chunk of plaintext.
pub fn compute_label_sum(
&mut self,
ciphertexts: &Vec<[Vec<u8>; 2]>,
labels: &Vec<u128>,
) -> BigUint {
) -> Vec<BigUint> {
// if binary label's p&p bit is 0, decrypt the 1st ciphertext,
// otherwise - the 2nd one.
assert!(ciphertexts.len() == labels.len());
let mut label_sum = BigUint::from_u8(0).unwrap();
let _unused: Vec<()> = ciphertexts
.iter()
.zip(labels)
.map(|(ct_pair, label)| {
assert!(self.plaintext.as_ref().unwrap().len() * 8 == ciphertexts.len());
let mut label_sum_hashes: Vec<BigUint> =
Vec::with_capacity(self.chunks.as_ref().unwrap().len());
let ct_iter = ciphertexts.chunks(self.chunk_size.unwrap());
let lb_iter = labels.chunks(self.chunk_size.unwrap());
// process a pair of chunks of ciphertexts and corresponding labels at a time
for (chunk_ct, chunk_lb) in ct_iter.zip(lb_iter){
// accumulate the label sum here
let mut label_sum = BigUint::from_u8(0).unwrap();
for (ct_pair, label) in chunk_ct.iter().zip(chunk_lb) {
let key = Aes128::new_from_slice(&label.to_be_bytes()).unwrap();
// choose which ciphertext to decrypt based on the point-and-permute bit
let mut ct = [0u8; 16];
@@ -115,12 +124,15 @@ impl LsumProver {
key.decrypt_block(&mut ct);
// add the decrypted arithmetic label to the sum
label_sum += BigUint::from_bytes_be(&ct);
})
.collect();
println!("{:?} label_sum", label_sum);
let label_sum_hash = self.poseidon(vec![label_sum]);
self.label_sum_hash = Some(label_sum_hash.clone());
label_sum_hash
};
println!("{:?} label_sum", label_sum);
let label_sum_hash = self.poseidon(vec![label_sum]);
label_sum_hashes.push(label_sum_hash);
}
self.label_sum_hashes = Some(label_sum_hashes.clone());
label_sum_hashes
}
// create chunks of plaintext where each chunk consists of 16 field elements.
@@ -131,6 +143,7 @@ impl LsumProver {
let useful_bits = self.useful_bits.unwrap();
// the size of a chunk of plaintext not counting the salt
let chunk_size = useful_bits * 16 - 128;
self.chunk_size = Some(chunk_size);
//let chunk_size = useful_bits * 16;
// plaintext converted into bits
@@ -193,10 +206,11 @@ impl LsumProver {
// hashes each chunk with Poseidon and returns digests for each chunk
fn hash_chunks(&mut self, chunks: Vec<[BigUint; 16]>) -> Vec<BigUint> {
return chunks
let res: Vec<BigUint> = chunks
.iter()
.map(|chunk| self.poseidon(chunk.to_vec()))
.collect();
res
}
// hash the inputs with circomlibjs's Poseidon
@@ -219,16 +233,16 @@ impl LsumProver {
let output = &output.stdout[0..output.stdout.len() - 1];
let s = String::from_utf8(output.to_vec()).unwrap();
let bi = s.parse::<BigUint>().unwrap();
println!("poseidon output {:?}", bi);
//println!("poseidon output {:?}", bi);
bi
}
pub fn create_zk_proof(
&mut self,
zero_sum: BigUint,
zero_sum: Vec<BigUint>,
deltas: Vec<BigUint>,
) -> Result<Vec<u8>, Error> {
let label_sum_hash = self.label_sum_hash.as_ref().unwrap().clone();
) -> Result<Vec<Vec<u8>>, Error> {
let label_sum_hashes = self.label_sum_hashes.as_ref().unwrap().clone();
// the last chunk will be padded with zero plaintext. We also should pad
// the deltas of the last chunk
@@ -243,54 +257,66 @@ impl LsumProver {
padded_deltas.extend(deltas);
padded_deltas.extend(padding);
// write inputs into input.json
let pt_str: Vec<String> = self.chunks.as_ref().unwrap()[0]
.to_vec()
.iter()
.map(|bigint| bigint.to_string())
.collect();
// For now dealing with one chunk only
let mut deltas_chunk: Vec<Vec<BigUint>> = Vec::with_capacity(16);
for i in 0..15 {
deltas_chunk.push(padded_deltas[i * 253..(i + 1) * 253].to_vec());
// we will have as many proofs as there are chunks of plaintext
let chunk_count = self.chunks.as_ref().unwrap().len();
let mut proofs: Vec<Vec<u8>> = Vec::with_capacity(chunk_count);
for count in 0..chunk_count {
// write inputs into input.json
let pt_str: Vec<String> = self.chunks.as_ref().unwrap()[count]
.to_vec()
.iter()
.map(|bigint| bigint.to_string())
.collect();
// For now dealing with one chunk only
let mut deltas_chunk: Vec<Vec<BigUint>> = Vec::with_capacity(16);
for i in 0..15 {
deltas_chunk.push(
padded_deltas[count * chunk_size + i * 253..count * chunk_size + (i + 1) * 253]
.to_vec(),
);
}
// There are as many deltas as there are bits in the plaintext
let delta_str: Vec<Vec<String>> = deltas_chunk
.iter()
.map(|v| v.iter().map(|b| b.to_string()).collect())
.collect();
let delta_last =
&padded_deltas[count * chunk_size + 15 * 253..count * chunk_size + 16 * 253 - 128];
let delta_last_str: Vec<String> = delta_last.iter().map(|v| v.to_string()).collect();
let mut data = object! {
plaintext_hash: self.hashes_of_chunks.as_ref().unwrap()[count].to_string(),
label_sum_hash: label_sum_hashes[count].to_string(),
sum_of_zero_labels: zero_sum[count].to_string(),
plaintext: pt_str,
delta: delta_str,
delta_last: delta_last_str
};
let s = stringify_pretty(data, 4);
let mut path1 = temp_dir();
let mut path2 = temp_dir();
path1.push(format!("input.json.{}", Uuid::new_v4()));
path2.push(format!("proof.json.{}", Uuid::new_v4()));
fs::write(path1.clone(), s).expect("Unable to write file");
let output = Command::new("node")
.args([
"prove.mjs",
path1.to_str().unwrap(),
path2.to_str().unwrap(),
])
.output();
fs::remove_file(path1);
check_output(output)?;
let proof = fs::read(path2.clone()).unwrap();
fs::remove_file(path2);
proofs.push(proof);
}
// There are as many deltas as there are bits in the plaintext
let delta_str: Vec<Vec<String>> = deltas_chunk
.iter()
.map(|v| v.iter().map(|b| b.to_string()).collect())
.collect();
let delta_last = &padded_deltas[15 * 253..16 * 253 - 128];
let delta_last_str: Vec<String> = delta_last.iter().map(|v| v.to_string()).collect();
let mut data = object! {
plaintext_hash: self.hashes_of_chunks.as_ref().unwrap()[0].to_string(),
label_sum_hash: label_sum_hash.to_string(),
sum_of_zero_labels: zero_sum.to_string(),
plaintext: pt_str,
delta: delta_str,
delta_last: delta_last_str
};
let s = stringify_pretty(data, 4);
let mut path1 = temp_dir();
let mut path2 = temp_dir();
path1.push(format!("input.json.{}", Uuid::new_v4()));
path2.push(format!("proof.json.{}", Uuid::new_v4()));
fs::write(path1.clone(), s).expect("Unable to write file");
let output = Command::new("node")
.args([
"prove.mjs",
path1.to_str().unwrap(),
path2.to_str().unwrap(),
])
.output();
fs::remove_file(path1);
check_output(output)?;
let proof = fs::read(path2.clone()).unwrap();
fs::remove_file(path2);
Ok(proof)
Ok(proofs)
}
}

View File

@@ -31,13 +31,13 @@ fn check_output(output: &Result<Output, std::io::Error>) -> Result<(), Error> {
pub struct LsumVerifier {
// hashes for each chunk of Prover's plaintext
plaintext_hashes: Option<Vec<BigUint>>,
labelsum_hash: Option<BigUint>,
labelsum_hashes: Option<Vec<BigUint>>,
// if set to true, then we must send the proving key to the Prover
// before this protocol begins. Otherwise, it is assumed that the Prover
// already has the proving key from a previous interaction with us.
proving_key_needed: bool,
deltas: Option<Vec<BigUint>>,
zero_sum: Option<BigUint>,
zero_sums: Option<Vec<BigUint>>,
ciphertexts: Option<Vec<[Vec<u8>; 2]>>,
useful_bits: usize,
}
@@ -46,10 +46,10 @@ impl LsumVerifier {
pub fn new(proving_key_needed: bool) -> Self {
Self {
plaintext_hashes: None,
labelsum_hash: None,
labelsum_hashes: None,
proving_key_needed,
deltas: None,
zero_sum: None,
zero_sums: None,
ciphertexts: None,
useful_bits: 253,
}
@@ -74,27 +74,54 @@ impl LsumVerifier {
// call to AES.
// To keep the handling simple, we want to avoid a negative delta, that's why
// W_0 and delta must be 127-bit values and W_1 will be set to W_0 + delta
let bitsize = labels.len();
let mut zero_sum = BigUint::from_u8(0).unwrap();
let chunk_size = 253 * 16 - 128;
let chunk_count = (bitsize + (chunk_size - 1)) / chunk_size;
let mut zero_sums: Vec<BigUint> = Vec::with_capacity(chunk_count);
let mut deltas: Vec<BigUint> = Vec::with_capacity(bitsize);
let arithm_labels: Vec<[BigUint; 2]> = (0..bitsize)
.map(|_| {
let zero_label = random_bigint(127);
let delta = random_bigint(127);
let one_label = zero_label.clone() + delta.clone();
zero_sum += zero_label.clone();
deltas.push(delta);
[zero_label, one_label]
})
.collect();
self.zero_sum = Some(zero_sum);
let mut all_arithm_labels: Vec<[BigUint; 2]> = Vec::with_capacity(bitsize);
for count in 0..chunk_count {
// calculate zero_sum for each chunk of plaintext separately
let mut zero_sum = BigUint::from_u8(0).unwrap();
// end of range is different for the last chunk
let end = if count < chunk_count - 1 {
(count + 1) * chunk_size
} else {
// compute the size of the gap at the end of the last chunk
let last_size = bitsize % chunk_size;
let gap_size = if last_size == 0 {
0
} else {
chunk_size - last_size
};
(count + 1) * chunk_size - gap_size
};
all_arithm_labels.append(
&mut (count * chunk_size..end)
.map(|_| {
let zero_label = random_bigint(127);
let delta = random_bigint(127);
let one_label = zero_label.clone() + delta.clone();
zero_sum += zero_label.clone();
deltas.push(delta);
[zero_label, one_label]
})
.collect(),
);
zero_sums.push(zero_sum);
}
self.zero_sums = Some(zero_sums);
self.deltas = Some(deltas);
// flatten all arithmetic labels
// encrypt each arithmetic label using a corresponding binary label as a key
// place ciphertexts in an order based on binary label's p&p bit
let ciphertexts: Vec<[Vec<u8>; 2]> = labels
.iter()
.zip(arithm_labels)
.zip(all_arithm_labels)
.map(|(bin_pair, arithm_pair)| {
let zero_key = Aes128::new_from_slice(&bin_pair[0].to_be_bytes()).unwrap();
let one_key = Aes128::new_from_slice(&bin_pair[1].to_be_bytes()).unwrap();
@@ -132,15 +159,15 @@ impl LsumVerifier {
// receive the hash commitment to the Prover's sum of labels and reveal all
// deltas and zero_sum.
pub fn receive_labelsum_hash(&mut self, hash: BigUint) -> (Vec<BigUint>, BigUint) {
self.labelsum_hash = Some(hash);
pub fn receive_labelsum_hash(&mut self, hashes: Vec<BigUint>) -> (Vec<BigUint>, Vec<BigUint>) {
self.labelsum_hashes = Some(hashes);
(
self.deltas.as_ref().unwrap().clone(),
self.zero_sum.as_ref().unwrap().clone(),
self.zero_sums.as_ref().unwrap().clone(),
)
}
pub fn verify(&mut self, proof: Vec<u8>) -> Result<bool, Error> {
pub fn verify(&mut self, proofs: Vec<Vec<u8>>) -> Result<bool, Error> {
// // Write public.json. The elements must be written in the exact order
// // as below, that's the order snarkjs expects them to be in.
@@ -158,6 +185,8 @@ impl LsumVerifier {
padded_deltas.extend(self.deltas.as_ref().unwrap().clone());
padded_deltas.extend(padding);
assert!(proofs.len() == chunk_count);
let mut chunks: Vec<Vec<Vec<BigUint>>> = Vec::with_capacity(chunk_count);
// current offset within bits
let mut offset: usize = 0;
@@ -172,51 +201,54 @@ impl LsumVerifier {
chunks.push(chunk);
}
// Even though there may be multiple chunks, we are dealing with
// one chunk for now.
for count in 0..chunk_count {
// There are as many deltas as there are bits in the plaintext
let delta_str: Vec<Vec<String>> = chunks[count][0..15]
.iter()
.map(|v| v.iter().map(|b| b.to_string()).collect())
.collect();
let delta_last_str: Vec<String> = chunks[0][15].iter().map(|v| v.to_string()).collect();
// There are as many deltas as there are bits in the plaintext
let delta_str: Vec<Vec<String>> = chunks[0][0..15]
.iter()
.map(|v| v.iter().map(|b| b.to_string()).collect())
.collect();
let delta_last_str: Vec<String> = chunks[0][15].iter().map(|v| v.to_string()).collect();
// public.json is a flat array
let mut public_json: Vec<String> = Vec::new();
public_json.push(
self.plaintext_hashes.as_ref().unwrap()[0]
.clone()
.to_string(),
);
public_json.push(
self.labelsum_hashes.as_ref().unwrap()[count]
.clone()
.to_string(),
);
public_json.extend::<Vec<String>>(delta_str.into_iter().flatten().collect());
public_json.extend(delta_last_str);
public_json.push(self.zero_sums.as_ref().unwrap()[count].clone().to_string());
// public.json is a flat array
let mut public_json: Vec<String> = Vec::new();
public_json.push(
self.plaintext_hashes.as_ref().unwrap()[0]
.clone()
.to_string(),
);
public_json.push(self.labelsum_hash.as_ref().unwrap().clone().to_string());
public_json.extend::<Vec<String>>(delta_str.into_iter().flatten().collect());
public_json.extend(delta_last_str);
public_json.push(self.zero_sum.as_ref().unwrap().clone().to_string());
let s = stringify(JsonValue::from(public_json.clone()));
let s = stringify(JsonValue::from(public_json.clone()));
let mut path1 = temp_dir();
let mut path2 = temp_dir();
path1.push(format!("public.json.{}", Uuid::new_v4()));
path2.push(format!("proof.json.{}", Uuid::new_v4()));
fs::write(path1.clone(), s).expect("Unable to write file");
fs::write(path2.clone(), proofs[count].clone()).expect("Unable to write file");
let mut path1 = temp_dir();
let mut path2 = temp_dir();
path1.push(format!("public.json.{}", Uuid::new_v4()));
path2.push(format!("proof.json.{}", Uuid::new_v4()));
fs::write(path1.clone(), s).expect("Unable to write file");
fs::write(path2.clone(), proof).expect("Unable to write file");
let output = Command::new("node")
.args([
"verify.mjs",
path1.to_str().unwrap(),
path2.to_str().unwrap(),
])
.output();
fs::remove_file(path1);
fs::remove_file(path2);
check_output(&output)?;
if output.unwrap().status.success() {
return Ok(true);
let output = Command::new("node")
.args([
"verify.mjs",
path1.to_str().unwrap(),
path2.to_str().unwrap(),
])
.output();
fs::remove_file(path1);
fs::remove_file(path2);
check_output(&output)?;
if !output.unwrap().status.success() {
return Ok(false);
}
}
return Ok(false);
return Ok(true);
}
}