diff --git a/src/ghash/mod.rs b/src/ghash/mod.rs index 660fe28..c876aeb 100644 --- a/src/ghash/mod.rs +++ b/src/ghash/mod.rs @@ -9,10 +9,7 @@ pub use verifier::Verifier; use crate::ole::Ole; -pub fn ghash(blocks: &[Gf2_128], h_prover: Gf2_128, h_verifier: Gf2_128) -> Gf2_128 { - let mut prover = Prover::new(blocks.len(), h_prover); - let mut verifier = Verifier::new(blocks.len(), h_verifier); - +pub fn ghash(blocks: &[Gf2_128], prover: &mut Prover, verifier: &mut Verifier) -> Gf2_128 { let mut ole = Ole::default(); prover.preprocess_ole_input(&mut ole); @@ -71,10 +68,13 @@ mod tests { let h2: Gf2_128 = Gf2_128::rand(&mut rng); let h = h1 + h2; - let ghash = ghash(&blocks, h1, h2); + let mut prover = Prover::new(blocks.len(), h1); + let mut verifier = Verifier::new(blocks.len(), h2); + + let ghash = ghash(&blocks, &mut prover, &mut verifier); let ghash_expected = { - let mut hi = vec![Gf2_128::one(), h]; + let mut hi = vec![h]; compute_product_repeated(&mut hi, h, blocks.len()); blocks @@ -83,7 +83,24 @@ mod tests { .fold(Gf2_128::zero(), |acc, (&b, &h)| acc + (b * h)) }; - assert_eq!(ghash.to_be_bytes(), ghash_expected.to_be_bytes()); + assert_eq!(ghash, ghash_expected); + } + + #[test] + fn test_ghash_invariants() { + let mut rng = thread_rng(); + let blocks: Vec = (0..10).map(|_| Gf2_128::rand(&mut rng)).collect(); + + let h1: Gf2_128 = Gf2_128::rand(&mut rng); + let h2: Gf2_128 = Gf2_128::rand(&mut rng); + + let mut prover = Prover::new(blocks.len(), h1); + let mut verifier = Verifier::new(blocks.len(), h2); + + let _ = ghash(&blocks, &mut prover, &mut verifier); + + assert_eq!(prover.h1, prover.hi[1]); + assert_eq!(verifier.h2, verifier.hi[1]); } #[test] diff --git a/src/ghash/prover.rs b/src/ghash/prover.rs index 2888343..f76df5e 100644 --- a/src/ghash/prover.rs +++ b/src/ghash/prover.rs @@ -1,16 +1,19 @@ use super::pascal_tri; use crate::ole::{Ole, Role}; -use mpz_share_conversion_core::fields::{compute_product_repeated, gf2_128::Gf2_128, UniformRand}; +use mpz_share_conversion_core::{ + fields::{compute_product_repeated, gf2_128::Gf2_128, UniformRand}, + Field, +}; use rand::thread_rng; #[derive(Debug)] pub struct Prover { - block_num: usize, - h1: Gf2_128, - r1: Gf2_128, - ai: Vec, - d: Option, - hi: Vec, + pub(crate) block_num: usize, + pub(crate) h1: Gf2_128, + pub(crate) r1: Gf2_128, + pub(crate) ai: Vec, + pub(crate) d: Option, + pub(crate) hi: Vec, } impl Prover { @@ -48,20 +51,24 @@ impl Prover { } pub fn handshake_a_set_hi(&mut self) { - let mut di = vec![Gf2_128::new(1)]; + let mut di = vec![Gf2_128::one(), self.d.unwrap()]; compute_product_repeated(&mut di, self.d.unwrap(), self.block_num); let pascal_tri = pascal_tri::(self.block_num); - for k in 0..self.block_num { - for el in pascal_tri[k].iter() { - self.hi.push(*el * di[k] * self.ai[self.block_num - k]); - } + for pascal_row in pascal_tri.iter() { + let h_pow_share = pascal_row + .iter() + .enumerate() + .fold(Gf2_128::new(0), |acc, (i, &el)| { + acc + el * di[i] * self.ai[pascal_row.len() - 1 - i] + }); + self.hi.push(h_pow_share); } } pub fn handshake_output_ghash(&self, blocks: &[Gf2_128]) -> Gf2_128 { - let mut res = Gf2_128::new(0); + let mut res = Gf2_128::zero(); for (i, block) in blocks.iter().enumerate() { res = res + *block * self.hi[i]; diff --git a/src/ghash/verifier.rs b/src/ghash/verifier.rs index ceb7712..581a9d1 100644 --- a/src/ghash/verifier.rs +++ b/src/ghash/verifier.rs @@ -1,16 +1,19 @@ use super::pascal_tri; use crate::ole::{Ole, Role}; -use mpz_share_conversion_core::fields::{compute_product_repeated, gf2_128::Gf2_128, UniformRand}; +use mpz_share_conversion_core::{ + fields::{compute_product_repeated, gf2_128::Gf2_128, UniformRand}, + Field, +}; use rand::thread_rng; #[derive(Debug)] pub struct Verifier { - block_num: usize, - h2: Gf2_128, - r2: Gf2_128, - bi: Vec, - d: Option, - hi: Vec, + pub(crate) block_num: usize, + pub(crate) h2: Gf2_128, + pub(crate) r2: Gf2_128, + pub(crate) bi: Vec, + pub(crate) d: Option, + pub(crate) hi: Vec, } impl Verifier { @@ -48,15 +51,19 @@ impl Verifier { } pub fn handshake_a_set_hi(&mut self) { - let mut di = vec![Gf2_128::new(1)]; + let mut di = vec![Gf2_128::one(), self.d.unwrap()]; compute_product_repeated(&mut di, self.d.unwrap(), self.block_num); let pascal_tri = pascal_tri::(self.block_num); - for k in 0..self.block_num { - for el in pascal_tri[k].iter() { - self.hi.push(*el * di[k] * self.bi[self.block_num - k]); - } + for pascal_row in pascal_tri.iter() { + let h_pow_share = pascal_row + .iter() + .enumerate() + .fold(Gf2_128::new(0), |acc, (i, &el)| { + acc + el * di[i] * self.bi[pascal_row.len() - 1 - i] + }); + self.hi.push(h_pow_share); } }