diff --git a/tls-2pc-core/Cargo.toml b/tls-2pc-core/Cargo.toml index af40df18b..6809c942d 100644 --- a/tls-2pc-core/Cargo.toml +++ b/tls-2pc-core/Cargo.toml @@ -33,6 +33,7 @@ rand.workspace = true thiserror.workspace = true serde = { workspace = true, features = ["derive"], optional = true } once_cell.workspace = true +share-conversion-core = { path = "../share-conversion-core" } [dev-dependencies] criterion.workspace = true diff --git a/tls-2pc-core/src/ghash/common.rs b/tls-2pc-core/src/ghash/common.rs deleted file mode 100644 index 8da485cc1..000000000 --- a/tls-2pc-core/src/ghash/common.rs +++ /dev/null @@ -1,153 +0,0 @@ -use super::{ - errors::*, - utils::{find_max_odd_power, find_sum, square_all}, -}; -use std::collections::BTreeMap; - -// GhashCommon is common to both Master and Slave -pub struct GhashCommon { - // blocks are input blocks for GHASH. In TLS the first block is AAD, - // the middle blocks are AES blocks - the ciphertext, the last block - // is len(AAD)+len(ciphertext) - pub blocks: Vec, - // powers are our XOR shares of the powers of H (H is the GHASH key). - // We need as many powers as there blocks. Value at key==1 corresponds to the share - // of H^1, value at key==2 to the share of H^2 etc. - pub powers: BTreeMap, - // max_odd_power is the maximum odd power that we'll need to compute - // GHASH in 2PC using Block Aggregation - pub max_odd_power: u8, - // strategies are initialized in ::new(). See comments there. - pub strategies: [BTreeMap; 2], - // temp_share is used to save an intermediate GHASH share - pub temp_share: Option, -} - -impl GhashCommon { - pub fn new(ghash_key_share: u128, blocks: Vec) -> Result { - if blocks.len() < 3 || blocks.len() > 1026 { - return Err(GhashError::BlockCountWrong); - } - let mut powers = BTreeMap::new(); - // GHASH key is our share H^1 - powers.insert(1, ghash_key_share); - let max_odd_power = find_max_odd_power(blocks.len() as u16); - powers = square_all(&powers, blocks.len() as u16); - - // strategy1 and startegy2 are only relevant for the Block Aggregation method. - // They show what existing shares of the powers of H (H is the GHASH key) we - // will be multiplying (value[0] and value[1]) to obtain other odd shares (). - // Max sequential odd share that we can obtain on first round of - // communication is 19. We already have 1) shares of H^1, H^2, H^3 from - // the Client Finished message and 2) squares of those 3 shares. - // Note that "sequential" is a keyword here. We can't obtain 21 but we - // indeed can obtain 25==24+1, 33==32+1 etc. However with 21 missing, - // even if we have 25,33,etc, there will be a gap and we will not be able - // to obtain all the needed shares by Block Aggregation. - - // We request OT for each share in each pair of the strategy, i.e. for - // shares: 4,1,4,3,8,1, etc. Even though it would be possible to introduce - // optimizations in order to avoid requesting OT for the same share more - // than once, that would only save us ~2000 OT instances at the cost of - // complicating the code. - - let strategy1: BTreeMap = BTreeMap::from([ - (5, [4, 1]), - (7, [4, 3]), - (9, [8, 1]), - (11, [8, 3]), - (13, [12, 1]), - (15, [12, 3]), - (17, [16, 1]), - (19, [16, 3]), - ]); - let strategy2: BTreeMap = BTreeMap::from([ - (21, [17, 4]), - (23, [17, 6]), - (25, [17, 8]), - (27, [19, 8]), - (29, [17, 12]), - (31, [19, 12]), - (33, [17, 16]), - (35, [19, 16]), - ]); - - Ok(Self { - max_odd_power, - blocks, - powers, - strategies: [strategy1, strategy2], - temp_share: Some(0u128), - }) - } - - /// Returns the amount of Oblivious Transfer instances needed to complete - /// the protocol. The purpose is to inform the OT layer. - pub fn calculate_ot_count(&mut self) -> usize { - let mut powers: BTreeMap = BTreeMap::new(); - // only 2 2PC multiplications in round 1 - let r1count = 2; - powers.insert(1, 0u128); - powers.insert(2, 0u128); - powers.insert(3, 0u128); - // since we just added a few new shares of powers, we need to square them - self.powers = square_all(&powers, self.blocks.len() as u16); - - // merge 2 strategies maps into 1 - let strategy: BTreeMap = self.strategies[0] - .clone() - .into_iter() - .chain(self.strategies[1].clone()) - .collect(); - // number of multiplications in rounds 2 and 3 - let mut r2and3count = 0; - for (key, _) in strategy.iter() { - if *key > self.max_odd_power { - break; - }; - powers.insert(*key as u16, 0u128); - r2and3count += 2; - } - // since we just added a few new shares of powers, we need to square them - powers = square_all(&powers, self.blocks.len() as u16); - - let mut aggregated: BTreeMap = BTreeMap::new(); - for i in 1..self.blocks.len() + 1 { - if powers.get(&(i as u16)) != None { - continue; - } - let (small, _) = find_sum(&powers, i as u16); - // initialize the value if it doesn't exist - if aggregated.get(&small) == None { - aggregated.insert(small, 0u128); - } - } - let r4count = aggregated.len() * 2; - // each multiplication requires 128 instances of Oblivious Transfer - (r1count + r2and3count + r4count) * 128 - } - - pub fn export_powers(&mut self) -> BTreeMap { - self.powers.clone() - } - - pub fn is_round2_needed(&self) -> bool { - // after round 1 we will have consecutive powers 1,2,3 which is enough - // to compute GHASH for 19 blocks with block aggregation. - self.blocks.len() > 19 - } - - pub fn is_round3_needed(&self) -> bool { - // after round 2 we will have a max of up to 19 consequitive odd powers - // which allows us to get 339 powers with block aggregation, see max_htable - // in utils::find_max_odd_power() - self.blocks.len() > 339 - } - - pub fn is_only_1_round(&self) -> bool { - // block agregation is always used except for very small block count - // where powers from round 1 are sufficient to perform direct multiplication - // of blocks by powers - self.blocks.len() <= 4 - } -} diff --git a/tls-2pc-core/src/ghash/core.rs b/tls-2pc-core/src/ghash/core.rs new file mode 100644 index 000000000..389daa669 --- /dev/null +++ b/tls-2pc-core/src/ghash/core.rs @@ -0,0 +1,129 @@ +use super::{ + compute_missing_mul_shares, compute_new_add_shares, mul, + state::{Finalized, Init, Intermediate, State}, + GhashError, +}; + +/// The core logic for our 2PC Ghash implementation +/// +/// `GhashCore` will do all the necessary computation +pub struct GhashCore { + /// Inner state + state: T, + /// Maximum number of message blocks we want to authenticate + max_message_length: usize, +} + +impl GhashCore { + /// Create a new `GhashCore` + /// + /// * `hashkey` - This is an additive sharing of `H`, which is the AES-encrypted 0 block + /// * `max_message_length` - Determines the maximum number of 128-bit message blocks we want to + /// authenticate + pub fn new(hashkey: u128, max_message_length: usize) -> Result { + if max_message_length == 0 { + return Err(GhashError::ZeroHashkeyPower); + } + + Ok(Self { + state: Init { add_share: hashkey }, + max_message_length, + }) + } + + /// Returns the original hashkey + /// + /// This is an additive sharing of `H` + pub fn h_additive(&self) -> u128 { + self.state.add_share + } + + /// Transform `self` into a `GhashCore`, holding multiplicative shares of + /// powers of `H` + /// + /// Converts `H` into `H`, `H^3`, `H^5`, ... depending on `self.max_message_length` + pub fn compute_odd_mul_powers(self, mul_share: u128) -> GhashCore { + let mut hashkey_powers = vec![mul_share]; + + compute_missing_mul_shares(&mut hashkey_powers, self.max_message_length); + + GhashCore { + state: Intermediate { + odd_mul_shares: hashkey_powers, + cached_add_shares: vec![], + }, + max_message_length: self.max_message_length, + } + } +} + +impl GhashCore { + /// Return odd multiplicative shares of the hashkey + /// + /// Takes into account cached additive shares, so that + /// multiplicative ones for which already an additive one + /// exists, are not returned. + pub fn odd_mul_shares(&self) -> Vec { + // If we already have some cached additive sharings, we do not need to compute new powers. + // So we compute an offset to ignore them. We divide by 2 because + // `self.state.cached_add_shares` contain even and odd powers, while + // `self.state.odd_mul_shares` only have odd powers. + let offset = self.state.cached_add_shares.len() / 2; + + self.state.odd_mul_shares[offset..].to_vec() + } + + /// Adds new additive shares of hashkey powers by also computing the even ones + /// and transforms `self` into a `GhashCore` + pub fn add_new_add_shares(mut self, new_additive_odd_shares: &[u128]) -> GhashCore { + compute_new_add_shares(new_additive_odd_shares, &mut self.state.cached_add_shares); + + GhashCore { + state: Finalized { + add_shares: self.state.cached_add_shares, + odd_mul_shares: self.state.odd_mul_shares, + }, + max_message_length: self.max_message_length, + } + } +} + +impl GhashCore { + /// Generate the GHASH output + /// + /// Computes the 2PC additive share of the GHASH output + pub fn ghash_output(&self, message: &[u128]) -> Result { + if message.len() > self.max_message_length { + return Err(GhashError::InvalidMessageLength); + } + let offset = self.state.add_shares.len() - message.len(); + Ok(message + .iter() + .zip(self.state.add_shares.iter().rev().skip(offset)) + .fold(0, |acc, (block, share)| acc ^ mul(*block, *share))) + } + + /// Change the maximum hashkey power + /// + /// If we want to create a GHASH output for a new message, which is longer than the old one, we need + /// to compute the missing powers of `H`. + pub fn change_max_hashkey(self, new_highest_hashkey_power: usize) -> GhashCore { + let mut hashkey_powers = self.state.odd_mul_shares; + compute_missing_mul_shares(&mut hashkey_powers, new_highest_hashkey_power); + + GhashCore { + state: Intermediate { + odd_mul_shares: hashkey_powers, + cached_add_shares: self.state.add_shares, + }, + max_message_length: new_highest_hashkey_power, + } + } +} + +#[cfg(test)] +impl GhashCore { + pub fn state(&self) -> &T { + &self.state + } +} diff --git a/tls-2pc-core/src/ghash/errors.rs b/tls-2pc-core/src/ghash/errors.rs deleted file mode 100644 index 8e5017f0d..000000000 --- a/tls-2pc-core/src/ghash/errors.rs +++ /dev/null @@ -1,12 +0,0 @@ -/// Errors that may occur when using ghash module -#[derive(Debug, thiserror::Error)] -pub enum GhashError { - #[error("Message was received out of order")] - OutOfOrder, - #[error("The other party sent data of wrong size")] - DataLengthWrong, - #[error("Tried to pass unsupported block count")] - BlockCountWrong, - #[error("Tried to finalize before the protocol was complete")] - FinalizeCalledTooEarly, -} diff --git a/tls-2pc-core/src/ghash/master.rs b/tls-2pc-core/src/ghash/master.rs deleted file mode 100644 index 46a0893b8..000000000 --- a/tls-2pc-core/src/ghash/master.rs +++ /dev/null @@ -1,247 +0,0 @@ -//! Implements the GHASH Master. This is the party which holds the Y value of -//! block multiplication. Master acts as the receiver of the Oblivious -//! Transfer and receives Slaves's masked X table entries obliviously for each -//! bit of Y. -use super::{ - errors::*, - utils::{ - block_aggregation, block_aggregation_bits, block_mult, flat_to_chunks, - multiply_powers_and_blocks, square_all, xor_sum, - }, - MasterCore, -}; -use crate::ghash::{common::GhashCommon, MXTable, YBits}; -use std::collections::BTreeMap; -use utils::iter::u8vec_to_boolvec; - -#[derive(PartialEq)] -pub enum MasterState { - Initialized, - // There may be 1, 2, 3 or 4 rounds depending on GHASH block count - RoundSent(usize), - RoundReceived(usize), - Complete, -} - -pub struct GhashMaster { - c: GhashCommon, - state: MasterState, - // is_last_round will be set to true by next_request() to indicate that - // after the response is received the state must be set to Complete - is_last_round: bool, -} - -impl MasterCore for GhashMaster { - fn next_request(&mut self) -> Result, GhashError> { - let retval; - let is_complete; - match self.state { - MasterState::Initialized => { - self.state = MasterState::RoundSent(1); - retval = self.get_ybits_for_round1().concat(); - is_complete = self.is_only_1_round(); - } - MasterState::RoundReceived(1) => { - if self.is_round2_needed() { - self.state = MasterState::RoundSent(2); - retval = self.get_ybits_for_round(2).concat(); - is_complete = false; - } else { - // rounds 2 and 3 will be skipped - self.state = MasterState::RoundSent(4); - retval = self.get_ybits_for_block_aggr().concat(); - is_complete = true; - } - } - MasterState::RoundReceived(2) => { - if self.is_round3_needed() { - self.state = MasterState::RoundSent(3); - retval = self.get_ybits_for_round(3).concat(); - is_complete = false; - } else { - // round 3 will be skipped - self.state = MasterState::RoundSent(4); - retval = self.get_ybits_for_block_aggr().concat(); - is_complete = true; - } - } - MasterState::RoundReceived(3) => { - self.state = MasterState::RoundSent(4); - retval = self.get_ybits_for_block_aggr().concat(); - is_complete = true; - } - _ => { - return Err(GhashError::OutOfOrder); - } - } - if is_complete { - self.is_last_round = true; - } - Ok(retval) - } - - fn process_response(&mut self, response: &Vec) -> Result<(), GhashError> { - if response.len() % 128 != 0 { - return Err(GhashError::DataLengthWrong); - } - let mxtables = flat_to_chunks(response, 128); - match self.state { - MasterState::RoundSent(1) => { - self.state = MasterState::RoundReceived(1); - self.process_mxtables_for_round1(&mxtables); - } - MasterState::RoundSent(2) => { - self.state = MasterState::RoundReceived(2); - self.process_mxtables_for_round(&mxtables, 2); - } - MasterState::RoundSent(3) => { - self.state = MasterState::RoundReceived(3); - self.process_mxtables_for_round(&mxtables, 3); - } - MasterState::RoundSent(4) => { - self.state = MasterState::RoundReceived(4); - self.process_mxtables_for_block_aggr(&mxtables); - } - _ => { - return Err(GhashError::OutOfOrder); - } - } - if self.is_last_round { - if self.state != MasterState::RoundReceived(4) { - // if the last round was not round 4 (i.e. there was no block - // aggregation), then we compute GHASH directly - self.c.temp_share = - Some(multiply_powers_and_blocks(&self.c.powers, &self.c.blocks)); - } - self.state = MasterState::Complete; - } - Ok(()) - } - - fn is_complete(&mut self) -> bool { - self.state == MasterState::Complete - } - - fn finalize(&mut self) -> Result { - if self.state != MasterState::Complete { - return Err(GhashError::FinalizeCalledTooEarly); - } - Ok(self.c.temp_share.unwrap()) - } - - /// Returns the amount of Oblivious Transfer instances needed to complete - /// the protocol. The purpose is to inform the OT layer. - fn calculate_ot_count(&mut self) -> usize { - self.c.calculate_ot_count() - } - - fn export_powers(&mut self) -> BTreeMap { - self.c.export_powers() - } -} - -impl GhashMaster { - pub fn new(ghash_key_share: u128, blocks: Vec) -> Result { - let common = GhashCommon::new(ghash_key_share, blocks)?; - Ok(Self { - c: common, - state: MasterState::Initialized, - is_last_round: false, - }) - } - - fn is_round2_needed(&self) -> bool { - self.c.is_round2_needed() - } - - fn is_round3_needed(&self) -> bool { - self.c.is_round3_needed() - } - - fn is_only_1_round(&self) -> bool { - self.c.is_only_1_round() - } - - /// Returns Y bits to compute H^3. - fn get_ybits_for_round1(&mut self) -> Vec { - vec![ - u8vec_to_boolvec(&self.c.powers[&1].to_be_bytes()), - u8vec_to_boolvec(&self.c.powers[&2].to_be_bytes()), - ] - } - - /// Takes masked X tables and computes our share of H^3. - fn process_mxtables_for_round1(&mut self, mxtables: &Vec) { - // the XOR sum of all masked xtables' values plus H^1*H^2 is our share of H^3 - self.c.powers.insert( - 3, - xor_sum(&mxtables[0]) - ^ xor_sum(&mxtables[1]) - ^ block_mult(self.c.powers[&1], self.c.powers[&2]), - ); - // since we just added a new share of powers, we need to square them - self.c.powers = square_all(&self.c.powers, self.c.blocks.len() as u16); - } - - // Returns Y bits for a given round of communication. - fn get_ybits_for_round(&mut self, round_no: u8) -> Vec { - assert!(round_no == 2 || round_no == 3); - let mut bits: Vec = Vec::new(); - for (key, value) in self.c.strategies[(round_no - 2) as usize].iter() { - if *key > self.c.max_odd_power { - break; - } - bits.push(u8vec_to_boolvec( - &self.c.powers[&(value[0] as u16)].to_be_bytes(), - )); - bits.push(u8vec_to_boolvec( - &self.c.powers[&(value[1] as u16)].to_be_bytes(), - )); - } - bits - } - - /// Processes masked X tables for a given round of communication. - fn process_mxtables_for_round(&mut self, mxtables: &Vec, round_no: u8) { - assert!(round_no == 2 || round_no == 3); - for (count, (power, factors)) in self.c.strategies[(round_no - 2) as usize] - .iter() - .enumerate() - { - if *power > self.c.max_odd_power { - // for every key in the strategy which we processed, there - // must have been 2 masked xtables - assert!(count * 2 == mxtables.len()); - break; - } - // the XOR sum of 2 masked xtables' values plus the locally computed - // term is our share of power - let sum = xor_sum(&mxtables[count * 2]) ^ xor_sum(&mxtables[(count * 2) + 1]); - let local_term = block_mult( - self.c.powers[&(factors[0] as u16)], - self.c.powers[&(factors[1] as u16)], - ); - self.c.powers.insert(*power as u16, sum ^ local_term); - } - // since we just added a few new shares of powers, we need to square them - self.c.powers = square_all(&self.c.powers, self.c.blocks.len() as u16); - } - - /// Returns Y bits for the block aggregation method. - fn get_ybits_for_block_aggr(&mut self) -> Vec { - let share1 = multiply_powers_and_blocks(&self.c.powers, &self.c.blocks); - let (aggregated, share2) = block_aggregation(&self.c.powers, &self.c.blocks); - let choice_bits = block_aggregation_bits(&self.c.powers, &aggregated); - self.c.temp_share = Some(share1 ^ share2); - choice_bits - } - - /// Processes masked X tables for the block aggregation method. - fn process_mxtables_for_block_aggr(&mut self, mxtables: &Vec) { - let mut share = 0u128; - for table in mxtables.iter() { - share ^= xor_sum(table); - } - self.c.temp_share = Some(self.c.temp_share.unwrap() ^ share); - } -} diff --git a/tls-2pc-core/src/ghash/mod.rs b/tls-2pc-core/src/ghash/mod.rs index 00a5b2e3d..c6ce9b0b3 100644 --- a/tls-2pc-core/src/ghash/mod.rs +++ b/tls-2pc-core/src/ghash/mod.rs @@ -1,295 +1,381 @@ -//! ghash implements a protocol of computing the AES-GCM's GHASH function in a -//! secure two-party computation (2PC) setting using 1-out-of-2 Oblivious -//! Transfer (OT). The parties start with their secret XOR shares of H (the -//! GHASH key) and at the end each gets their XOR share of the GHASH output. -//! The method is decribed here: -//! (https://tlsnotary.org/how_it_works#section4). - -//! As an illustration, let's say that S has his shares H1_s and H2_s and R -//! has her shares H1_r and H2_r. They need to compute shares of H3. -//! H3 = (H1_s + H1_r)*(H2_s + H2_r) = H1_s*H2_s + H1_s*H2_r + H1_r*H2_s + -//! H1_r*H2_r. Term 1 can be computed by S locally and term 4 can be -//! computed by R locally. Only terms 2 and 3 will be computed using -//! GHASH 2PC. R will obliviously request values for bits of H1_r and H2_r. -//! The XOR sum of all values which S will send back plus H1_r*H2_r will -//! become R's share of H3. +//! This module implements the AES-GCM's GHASH function in a secure two-party computation (2PC) +//! setting. The parties start with their secret XOR shares of H (the GHASH key) and at the end +//! each gets their XOR share of the GHASH output. The method is described here: +//! . //! -//! When performing block multiplication in 2PC, Master holds the Y value and -//! Slave holds the X value. The Slave then computes the X table and masks it. +//! At first we will convert the XOR (additive) share of `H`, into a multiplicative share. This +//! allows us to compute all the necessary powers of `H^n` locally. Note, that it is only required +//! to compute the odd multiplicative powers, because of free squaring. Then each of these +//! multiplicative shares will be converted back into additive shares. The even additive shares can +//! then locally be built by using the odd ones. This way, we can batch nearly all oblivious +//! transfers and reduce the round complexity of the protocol. +//! +//! On the whole, we need a single additive-to-multiplicative (A2M) and `n/2`, where `n` is the +//! number of blocks of message, multiplicative-to-additive (M2A) conversions. Finally, having +//! additive shares of `H^n` for all needed `n`, we can compute an additive share of the GHASH +//! output. -mod common; -pub mod errors; -pub mod master; -pub mod slave; -mod utils; +/// Contains the core logic for ghash +mod core; -use errors::*; -use std::collections::BTreeMap; +/// Contains the different states +pub mod state; -/// MXTableFull is masked XTable which Slave has at the beginning of OT. -/// MXTableFull must not be revealed to Master. -type MXTableFull = Vec<[u128; 2]>; -/// MXTable is a masked x table which Master will end up having after OT. -type MXTable = Vec; -/// YBits are Master's bits of Y in big-endian. Based on these bits -/// Master will send MXTable via OT. -/// The convention for the returned Y bits: -/// A) powers are in an ascending order: first powers[1], then powers[2] etc. -/// B) bits of each power are in big-endian. -type YBits = Vec; +pub use crate::ghash::core::GhashCore; +use share_conversion_core::gf2_128::{compute_product_repeated, mul}; +use thiserror::Error; -pub trait MasterCore { - /// Returns choice bits for Oblivious Transfer. - /// While is_complete() returns false, next_request() must be called - /// followed by process_response(). - fn next_request(&mut self) -> Result, GhashError>; - - /// process_response() will be invoked by the Oblivious Transfer impl. It - /// receives masked X tables acc. to our choice bits in next_request(). - fn process_response(&mut self, response: &Vec) -> Result<(), GhashError>; - - /// Returns true when the protocol is complete. - fn is_complete(&mut self) -> bool; - - /// Returns our GHASH share. - fn finalize(&mut self) -> Result; - - /// Returns the amount of Oblivious Transfer instances needed to complete - /// the protocol. The purpose is to inform the OT layer. - fn calculate_ot_count(&mut self) -> usize; - - /// Exports powers of the GHASH key obtained at the current stage of the - /// protocol. - fn export_powers(&mut self) -> BTreeMap; +#[derive(Debug, Error)] +pub enum GhashError { + #[error("Invalid maximum hashkey power")] + ZeroHashkeyPower, + #[error("Message too long")] + InvalidMessageLength, } -pub trait SlaveCore { - /// Returns the full masked X table which must NOT be passed to Master. It - /// must be consumed by the Oblivious Transfer impl. - fn process_request(&mut self) -> Result, GhashError>; +/// Computes missing odd multiplicative shares of the hashkey powers +/// +/// Checks if depending on the number of `needed` shares, we need more odd multiplicative shares and +/// computes them. Notice that we only need odd multiplicative shares for the OT, because we can +/// derive even additive shares from odd additive shares, which we call free squaring. +/// +/// * `present_odd_mul_shares` - multiplicative odd shares already present +/// * `needed` - how many powers we need including odd and even +fn compute_missing_mul_shares(present_odd_mul_shares: &mut Vec, needed: usize) { + // divide by 2 and round up + let needed_odd_powers: usize = needed / 2 + (needed & 1); + let present_odd_len = present_odd_mul_shares.len(); - /// Returns true when the protocol is complete. - fn is_complete(&mut self) -> bool; + if needed_odd_powers > present_odd_len { + let h_squared = mul(present_odd_mul_shares[0], present_odd_mul_shares[0]); + compute_product_repeated( + present_odd_mul_shares, + h_squared, + needed_odd_powers - present_odd_len, + ); + } +} - /// Returns our GHASH share. - fn finalize(&mut self) -> Result; +/// Computes new even (additive) shares from new odd (additive) shares and saves both the new odd shares +/// and the new even shares. +/// +/// This function implements the derivation of even additive shares from odd additive shares, +/// which we refer to as free squaring. Every additive share of an even power of +/// `H` can be computed without an OT interaction by squaring the corresponding additive share +/// of an odd power of `H`, e.g. if we have a share of H^3, we can derive the share of H^6 by doing +/// (H^3)^2 +/// +/// * `new_add_odd_shares` - new odd additive shares we got as a result of doing an OT on odd +/// multiplicative shares +/// * `add_shares` - all additive shares (even and odd) we already have. This is a mutable +/// reference to cached_add_shares in [crate::ghash::state::Intermediate] +fn compute_new_add_shares(new_add_odd_shares: &[u128], add_shares: &mut Vec) { + for (odd_share, current_odd_power) in new_add_odd_shares + .iter() + .zip((add_shares.len() + 1..).step_by(2)) + { + // `add_shares` always have an even number of shares so we simply add the next odd share + add_shares.push(*odd_share); - /// Returns the amount of Oblivious Transfer instances needed to complete - /// the protocol. The purpose is to inform the OT layer. - fn calculate_ot_count(&mut self) -> usize; - - /// Exports powers of the GHASH key obtained at the current stage of the - /// protocol. - fn export_powers(&mut self) -> BTreeMap; + // now we need to compute the next even share and add it + // note that the n-th index corresponds to the (n+1)-th power, e.g. add_shares[4] + // is the share of H^5 + let mut base_share = add_shares[current_odd_power / 2]; + base_share = mul(base_share, base_share); + add_shares.push(base_share); + } } #[cfg(test)] mod tests { - use super::{ - errors::GhashError, master::GhashMaster, slave::GhashSlave, utils::block_mult, MasterCore, - SlaveCore, - }; use ghash_rc::{ universal_hash::{NewUniversalHash, UniversalHash}, GHash, }; - use rand::{prelude::ThreadRng, thread_rng, Rng}; - use std::convert::TryInto; + use rand::{Rng, SeedableRng}; + use rand_chacha::ChaCha12Rng; + use share_conversion_core::gf2_128::inverse; + + use super::{ + compute_missing_mul_shares, compute_new_add_shares, compute_product_repeated, mul, + state::{Finalized, Intermediate}, + GhashCore, + }; #[test] - // test only round 1 - fn test_round1() { - let block_count = 3; - let (h, mut slave, mut master, blocks) = ghash_setup(block_count); - run_round(&mut slave, &mut master).unwrap(); - let ghash = finalize(&mut slave, &mut master); - assert_eq!(ghash, rust_crypto_ghash(h, &blocks)); - assert!(master.is_complete()); - assert!(slave.is_complete()); + fn test_ghash_product_sharing() { + let mut rng = ChaCha12Rng::from_seed([0; 32]); + + // The Ghash key + let h: u128 = rng.gen(); + let message = gen_u128_vec(); + let message_len = message.len(); + let number_of_powers_needed: usize = message_len / 2 + (message_len & 1); + + let (sender, receiver) = setup_ghash_to_intermediate_state(h, message_len); + + let mut powers_h = vec![h]; + compute_product_repeated(&mut powers_h, mul(h, h), number_of_powers_needed); + + // Length check + assert_eq!(sender.state().odd_mul_shares.len(), number_of_powers_needed); + assert_eq!( + receiver.state().odd_mul_shares.len(), + number_of_powers_needed + ); + + // Product check + for (k, (sender_share, receiver_share)) in std::iter::zip( + sender.state().odd_mul_shares.iter(), + receiver.state().odd_mul_shares.iter(), + ) + .enumerate() + { + assert_eq!(mul(*sender_share, *receiver_share), powers_h[k]); + } } #[test] - // test state after rounds 1,2 (but before block aggregation) - fn test_round12_before_block_aggregation() { - let block_count = 30; - let (h, mut slave, mut master, _blocks) = ghash_setup(block_count); - run_round(&mut slave, &mut master).unwrap(); - run_round(&mut slave, &mut master).unwrap(); + fn test_ghash_sum_sharing() { + let mut rng = ChaCha12Rng::from_seed([0; 32]); - let s_powers = slave.export_powers(); - let r_powers = master.export_powers(); - let all_s_keys: Vec = s_powers.keys().cloned().collect(); - let all_r_keys: Vec = r_powers.keys().cloned().collect(); - let expected_keys = vec![1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28]; - let exp_powers = compute_expected_powers(h, block_count as u16); + // The Ghash key + let h: u128 = rng.gen(); + let message = gen_u128_vec(); + let message_len = message.len(); - assert_eq!(all_s_keys, expected_keys); - assert_eq!(all_r_keys, expected_keys); - // compare shares of powers against expected powers - for key in expected_keys.iter() { + let (sender, receiver) = setup_ghash_to_intermediate_state(h, message_len); + let (sender, receiver) = ghash_to_finalized(sender, receiver); + + let mut powers_h = vec![h]; + compute_product_repeated(&mut powers_h, h, message_len); + + // Length check + assert_eq!( + sender.state().add_shares.len(), + message_len + (message_len & 1) + ); + assert_eq!( + receiver.state().add_shares.len(), + message_len + (message_len & 1) + ); + + // Sum check + for k in 0..message_len { assert_eq!( - exp_powers[*key as usize], - *s_powers.get(key).unwrap() ^ *r_powers.get(key).unwrap() + sender.state().add_shares[k] ^ receiver.state().add_shares[k], + powers_h[k] ); } - assert!(!master.is_complete()); - assert!(!slave.is_complete()); } #[test] - // test rounds 1,2,4 - fn test_round124() { - let block_count = 30; - let (h, mut slave, mut master, blocks) = ghash_setup(block_count); - run_round(&mut slave, &mut master).unwrap(); - run_round(&mut slave, &mut master).unwrap(); - run_round(&mut slave, &mut master).unwrap(); - let ghash = finalize(&mut slave, &mut master); - assert_eq!(ghash, rust_crypto_ghash(h, &blocks)); - assert!(master.is_complete()); - assert!(slave.is_complete()); - } + fn test_ghash_output() { + let mut rng = ChaCha12Rng::from_seed([0; 32]); - #[test] - // test rounds 1,2,3,4 - fn test_round1234() { - let block_count = 340; - let (h, mut slave, mut master, blocks) = ghash_setup(block_count); - run_round(&mut slave, &mut master).unwrap(); - run_round(&mut slave, &mut master).unwrap(); - run_round(&mut slave, &mut master).unwrap(); - run_round(&mut slave, &mut master).unwrap(); - let ghash = finalize(&mut slave, &mut master); - assert_eq!(ghash, rust_crypto_ghash(h, &blocks)); - assert!(master.is_complete()); - assert!(slave.is_complete()); - } - - #[test] - // test export_powers() after round 1 - fn test_export_powers() { - let block_count = 340; - let (h, mut slave, mut master, blocks) = ghash_setup(block_count); - run_round(&mut slave, &mut master).unwrap(); - let powers_s = slave.export_powers(); - let powers_r = master.export_powers(); - // we only have 4 consecutive powers 1,2,3,4 after round 1 - // we compute ghash for only the first 4 blocks - let mut ghash_share_s = 0u128; - for i in 0..4 { - ghash_share_s ^= block_mult(*powers_s.get(&(4 - i)).unwrap(), blocks[i as usize]); - } - let mut ghash_share_r = 0u128; - for i in 0..4 { - ghash_share_r ^= block_mult(*powers_r.get(&(4 - i)).unwrap(), blocks[i as usize]); - } - let ghash = ghash_share_s ^ ghash_share_r; - assert_eq!(ghash, rust_crypto_ghash(h, &blocks[0..4].to_vec())); - assert!(!master.is_complete()); - assert!(!slave.is_complete()); - } - - #[test] - // test OT count against hard-coded values and also double-check against - // the actual amount of Y bits which the Master requested. - fn test_calculate_ot_count() { - let block_counts = vec![3, 19, 200, 1026]; - // expected amount of 2PC block multiplications - let expected = vec![2, 8, 44, 96]; - - for i in 0..block_counts.len() { - let block_count = block_counts[i]; - let (_, mut slave, mut master, _) = ghash_setup(block_count); - let mut ybits_count = 0; - while !master.is_complete() { - let req = master.next_request().unwrap(); - ybits_count += req.len(); - let resp = slave.process_request().unwrap(); - let masked_xtable = simulate_ot(&req, &resp); - master.process_response(&masked_xtable).unwrap(); - } - let ot_count = master.calculate_ot_count(); - assert!(ot_count == expected[i] * 128); - assert!(ot_count == ybits_count); - } - } - - fn finalize(sender: &mut GhashSlave, receiver: &mut GhashMaster) -> u128 { - let sender_ghash_share = sender.finalize().unwrap(); - let receiver_ghash_share = receiver.finalize().unwrap(); - receiver_ghash_share ^ sender_ghash_share - } - - fn ghash_setup(block_count: usize) -> (u128, GhashSlave, GhashMaster, Vec) { - let mut rng = thread_rng(); - // h is ghash key + // The Ghash key let h: u128 = rng.gen(); - // h_s is sender's XOR share of h - let h_s: u128 = rng.gen(); - // h_r is receiver's XOR share of h - let h_r: u128 = h ^ h_s; + let message = gen_u128_vec(); - let blocks: Vec = random_blocks(block_count); - let sender = GhashSlave::new(rng, h_s, blocks.clone()).unwrap(); - let receiver = GhashMaster::new(h_r, blocks.clone()).unwrap(); - (h, sender, receiver, blocks) + let (sender, receiver) = setup_ghash_to_intermediate_state(h, message.len()); + let (sender, receiver) = ghash_to_finalized(sender, receiver); + + assert_eq!( + sender.ghash_output(&message).unwrap() ^ receiver.ghash_output(&message).unwrap(), + ghash_reference_impl(h, message) + ); } - fn random_blocks(block_count: usize) -> Vec { - let mut rng = thread_rng(); - let mut blocks: Vec = Vec::new(); - for _i in 0..block_count { - blocks.push(rng.gen()); + #[test] + fn test_ghash_change_message_short() { + let mut rng = ChaCha12Rng::from_seed([0; 32]); + + // The Ghash key + let h: u128 = rng.gen(); + let message = gen_u128_vec(); + + let (sender, receiver) = setup_ghash_to_intermediate_state(h, message.len()); + let (sender, receiver) = ghash_to_finalized(sender, receiver); + + let mut message_short: Vec = vec![0; message.len() / 2]; + message_short.iter_mut().for_each(|x| *x = rng.gen()); + + let (sender, receiver) = ( + sender.change_max_hashkey(message_short.len()), + receiver.change_max_hashkey(message_short.len()), + ); + + let (sender, receiver) = ghash_to_finalized(sender, receiver); + + assert_eq!( + sender.ghash_output(&message_short).unwrap() + ^ receiver.ghash_output(&message_short).unwrap(), + ghash_reference_impl(h, message_short) + ); + } + + #[test] + fn test_ghash_change_message_long() { + let mut rng = ChaCha12Rng::from_seed([0; 32]); + + // The Ghash key + let h: u128 = rng.gen(); + let message = gen_u128_vec(); + + let (sender, receiver) = setup_ghash_to_intermediate_state(h, message.len()); + let (sender, receiver) = ghash_to_finalized(sender, receiver); + + let mut message_long: Vec = vec![0; 2 * message.len()]; + message_long.iter_mut().for_each(|x| *x = rng.gen()); + + let (sender, receiver) = ( + sender.change_max_hashkey(message_long.len()), + receiver.change_max_hashkey(message_long.len()), + ); + + let (sender, receiver) = ghash_to_finalized(sender, receiver); + + assert_eq!( + sender.ghash_output(&message_long).unwrap() + ^ receiver.ghash_output(&message_long).unwrap(), + ghash_reference_impl(h, message_long) + ); + } + + #[test] + fn test_compute_missing_mul_shares() { + let mut rng = ChaCha12Rng::from_seed([0; 32]); + let h: u128 = rng.gen(); + let mut powers: Vec = vec![h]; + compute_product_repeated(&mut powers, mul(h, h), rng.gen_range(16..128)); + + let powers_len = powers.len(); + let needed = rng.gen_range(1..256); + + compute_missing_mul_shares(&mut powers, needed); + + // Check length + if needed / 2 + (needed & 1) <= powers_len { + assert_eq!(powers.len(), powers_len); + } else { + assert_eq!(powers.len(), needed / 2 + (needed & 1)) + } + + // Check shares + let first = *powers.first().unwrap(); + let factor = mul(first, first); + + let mut expected = first; + for share in powers.iter() { + assert_eq!(*share, expected); + expected = mul(expected, factor); } - blocks } - // compute GHASH using RustCrypto's ghash - fn rust_crypto_ghash(h: u128, blocks: &Vec) -> u128 { + #[test] + fn test_compute_new_add_shares() { + let mut rng = ChaCha12Rng::from_seed([0; 32]); + + let new_add_odd_shares: Vec = gen_u128_vec(); + let mut add_shares: Vec = gen_u128_vec(); + + // We have the invariant that len of add_shares is always even + if add_shares.len() & 1 == 1 { + add_shares.push(rng.gen()); + } + + let original_len = add_shares.len(); + + compute_new_add_shares(&new_add_odd_shares, &mut add_shares); + + // Check new length + assert_eq!( + add_shares.len(), + original_len + 2 * new_add_odd_shares.len() + ); + + // Check odd shares + for (k, l) in (original_len..add_shares.len()) + .step_by(2) + .zip(0..original_len) + { + assert_eq!(add_shares[k], new_add_odd_shares[l]); + } + + // Check even shares + for k in (original_len + 1..add_shares.len()).step_by(2) { + assert_eq!(add_shares[k], mul(add_shares[k / 2], add_shares[k / 2])); + } + } + + fn gen_u128_vec() -> Vec { + let mut rng = ChaCha12Rng::from_seed([0; 32]); + + // Sample some message + let message_len: usize = rng.gen_range(16..128); + let mut message: Vec = vec![0_u128; message_len]; + message.iter_mut().for_each(|x| *x = rng.gen()); + message + } + + fn ghash_reference_impl(h: u128, message: Vec) -> u128 { let mut ghash = GHash::new(&h.to_be_bytes().into()); - for block in blocks.iter() { - ghash.update(&block.to_be_bytes().into()); + for el in message { + ghash.update(&el.to_be_bytes().into()); } - let b = ghash.finalize().into_bytes(); - u128::from_be_bytes(b.as_slice().try_into().unwrap()) + let ghash_output = ghash.finalize(); + u128::from_be_bytes(ghash_output.into_bytes().try_into().unwrap()) } - // prepare the expected powers of h by recursively multiplying h to - // itself - fn compute_expected_powers(h: u128, max: u16) -> Vec { - // prepare the expected powers of h by recursively multiplying h to - // itself - let mut powers: Vec = vec![0u128; (max + 1) as usize]; - powers[1] = h; - let mut prev_power = h; - for i in 2..((max + 1) as usize) { - powers[i] = block_mult(prev_power, h); - prev_power = powers[i]; - } - powers + fn setup_ghash_to_intermediate_state( + hashkey: u128, + max_hashkey_power: usize, + ) -> (GhashCore, GhashCore) { + let mut rng = ChaCha12Rng::from_seed([0; 32]); + + // The additive sharings of the Ghash key to begin with + let h1_additive: u128 = rng.gen(); + let h2_additive: u128 = hashkey ^ h1_additive; + + // Create a multiplicative sharing + let h1_multiplicative: u128 = rng.gen(); + let h2_multiplicative: u128 = mul(hashkey, inverse(h1_multiplicative)); + + let sender = GhashCore::new(h1_additive, max_hashkey_power).unwrap(); + let receiver = GhashCore::new(h2_additive, max_hashkey_power).unwrap(); + + let (sender, receiver) = ( + sender.compute_odd_mul_powers(h1_multiplicative), + receiver.compute_odd_mul_powers(h2_multiplicative), + ); + + (sender, receiver) } - // run_round runs the next round - fn run_round( - sender: &mut GhashSlave, - receiver: &mut GhashMaster, - ) -> Result<(), GhashError> { - let receiver_bits = receiver.next_request()?; - let masked_xtable_full = sender.process_request()?; - let masked_xtable = simulate_ot(&receiver_bits, &masked_xtable_full); - receiver.process_response(&masked_xtable)?; - Ok(()) + fn ghash_to_finalized( + sender: GhashCore, + receiver: GhashCore, + ) -> (GhashCore, GhashCore) { + let (add_shares_sender, add_shares_receiver) = + m2a(&sender.odd_mul_shares(), &receiver.odd_mul_shares()); + let (sender, receiver) = ( + sender.add_new_add_shares(&add_shares_sender), + receiver.add_new_add_shares(&add_shares_receiver), + ); + (sender, receiver) } - // normally Master will send his bits via OT to get only 1 out of 2 values - // for each row of masked xtable. Here we simulate this OT behaviour. - fn simulate_ot(receiver_bits: &Vec, mxtables_full: &Vec<[u128; 2]>) -> Vec { - assert!(receiver_bits.len() == mxtables_full.len()); - let mut mxtables: Vec = Vec::new(); - for i in 0..mxtables_full.len() { - let choice = receiver_bits[i] as usize; - mxtables.push(mxtables_full[i][choice]); + fn m2a(first: &[u128], second: &[u128]) -> (Vec, Vec) { + let mut rng = ChaCha12Rng::from_seed([0; 32]); + let mut first_out = vec![]; + let mut second_out = vec![]; + for (j, k) in first.iter().zip(second.iter()) { + let product = mul(*j, *k); + let first_summand: u128 = rng.gen(); + let second_summand: u128 = product ^ first_summand; + first_out.push(first_summand); + second_out.push(second_summand); } - mxtables + (first_out, second_out) } } diff --git a/tls-2pc-core/src/ghash/slave.rs b/tls-2pc-core/src/ghash/slave.rs deleted file mode 100644 index c4ffb348e..000000000 --- a/tls-2pc-core/src/ghash/slave.rs +++ /dev/null @@ -1,185 +0,0 @@ -//! Implements the GHASH Slave. This is the party which holds the X value of -//! block multiplication. Slave acts as the sender of the Oblivious -//! Transfer and sends masked x_table entries obliviously for each -//! bit of Y received from the GHASH Master. - -use super::{ - common::GhashCommon, - errors::*, - utils::{ - block_aggregation, block_aggregation_mxtables, block_mult, free_square, masked_xtable, - multiply_powers_and_blocks, square_all, - }, - MXTableFull, SlaveCore, -}; -use rand::{CryptoRng, Rng}; -use std::collections::BTreeMap; - -#[derive(PartialEq)] -pub enum SlaveState { - Initialized, - // There may be from 1 to 4 rounds depending on GHASH block count - RoundReceived(usize), - Complete, -} - -pub struct GhashSlave { - c: GhashCommon, - rng: R, - state: SlaveState, -} - -impl SlaveCore for GhashSlave { - fn process_request(&mut self) -> Result, GhashError> { - let retval; - let is_complete; - match self.state { - SlaveState::Initialized => { - self.state = SlaveState::RoundReceived(1); - retval = self.get_mxtables_for_round1().concat(); - is_complete = self.is_only_1_round(); - } - SlaveState::RoundReceived(1) => { - if self.is_round2_needed() { - self.state = SlaveState::RoundReceived(2); - retval = self.get_mxtables_for_round(2).concat(); - is_complete = false; - } else { - // rounds 2 and 3 will be skipped - self.state = SlaveState::RoundReceived(4); - retval = self.get_mxtables_for_block_aggr().concat(); - is_complete = true; - } - } - SlaveState::RoundReceived(2) => { - if self.is_round3_needed() { - self.state = SlaveState::RoundReceived(3); - retval = self.get_mxtables_for_round(3).concat(); - is_complete = false; - } else { - // round 3 will be skipped - self.state = SlaveState::RoundReceived(4); - retval = self.get_mxtables_for_block_aggr().concat(); - is_complete = true; - } - } - SlaveState::RoundReceived(3) => { - self.state = SlaveState::RoundReceived(4); - retval = self.get_mxtables_for_block_aggr().concat(); - is_complete = true; - } - _ => { - return Err(GhashError::OutOfOrder); - } - } - if is_complete { - if self.state != SlaveState::RoundReceived(4) { - // if the last round was not round 4 (i.e. there was no block - // aggregation), then we compute GHASH directly - self.c.temp_share = - Some(multiply_powers_and_blocks(&self.c.powers, &self.c.blocks)); - } - self.state = SlaveState::Complete; - } - Ok(retval) - } - - fn finalize(&mut self) -> Result { - if self.state != SlaveState::Complete { - return Err(GhashError::FinalizeCalledTooEarly); - } - Ok(self.c.temp_share.unwrap()) - } - - fn is_complete(&mut self) -> bool { - self.state == SlaveState::Complete - } - - fn export_powers(&mut self) -> BTreeMap { - self.c.export_powers() - } - - fn calculate_ot_count(&mut self) -> usize { - self.c.calculate_ot_count() - } -} - -impl GhashSlave { - pub fn new(rng: R, ghash_key_share: u128, blocks: Vec) -> Result { - let c = GhashCommon::new(ghash_key_share, blocks)?; - Ok(Self { - c, - rng, - state: SlaveState::Initialized, - }) - } - - fn is_round2_needed(&self) -> bool { - self.c.is_round2_needed() - } - - fn is_round3_needed(&self) -> bool { - self.c.is_round3_needed() - } - - fn is_only_1_round(&self) -> bool { - self.c.is_only_1_round() - } - - /// Returns the masked X table for round 1. - /// Since the Master (M) sends bits for his powers in ascending order, we need to - /// accomodate that order, i.e. if we need to multiply M's H^1 by - /// our H^2 and then multiply M's H^2 by our H^1, then we return [mxtable - /// for H^2 + mxtable for H^1]. - fn get_mxtables_for_round1(&mut self) -> Vec { - self.c.powers.insert(2, free_square(self.c.powers[&1])); - let (masked1, h3_share1) = masked_xtable(&mut self.rng, self.c.powers[&1]); - let (masked2, h3_share2) = masked_xtable(&mut self.rng, self.c.powers[&2]); - - self.c.powers.insert( - 3, - block_mult(self.c.powers[&1], self.c.powers[&2]) ^ h3_share1 ^ h3_share2, - ); - // since we just added a new share of powers, we need to square them - self.c.powers = square_all(&self.c.powers, self.c.blocks.len() as u16); - vec![masked2, masked1] - } - - /// Returns masked X tables for either round 2 or round 3. - fn get_mxtables_for_round(&mut self, round_no: u8) -> Vec { - assert!(round_no == 2 || round_no == 3); - let mut all_mxtables: Vec = Vec::new(); - for (key, value) in self.c.strategies[(round_no - 2) as usize].clone().iter() { - if *key > self.c.max_odd_power { - break; - } - // Since Master sends bits in ascending order: factor1 bits, - // factor2 bits, we must return mxtables in descending order: - // factor2 mxtable, factor1 mxtable. - let factor1 = self.c.powers[&(value[0] as u16)]; - let factor2 = self.c.powers[&(value[1] as u16)]; - let (mxtable1, sum1) = masked_xtable(&mut self.rng, factor1); - let (mxtable2, sum2) = masked_xtable(&mut self.rng, factor2); - all_mxtables.push(mxtable2); - all_mxtables.push(mxtable1); - - // our share of power is the locally computed term plus sums - // of masks of each cross-term. - let local_term = block_mult(factor1, factor2); - self.c.powers.insert(*key as u16, local_term ^ sum1 ^ sum2); - } - // since we just added a few new shares of powers, we need to square them - self.c.powers = square_all(&self.c.powers, self.c.blocks.len() as u16); - all_mxtables - } - - /// Returns masked X tables for the block aggregation method. - fn get_mxtables_for_block_aggr(&mut self) -> Vec { - let share1 = multiply_powers_and_blocks(&self.c.powers, &self.c.blocks); - let (aggregated, share2) = block_aggregation(&self.c.powers, &self.c.blocks); - let (mxtables, share3) = - block_aggregation_mxtables(&mut self.rng, &self.c.powers, &aggregated); - self.c.temp_share = Some(share1 ^ share2 ^ share3); - mxtables - } -} diff --git a/tls-2pc-core/src/ghash/state.rs b/tls-2pc-core/src/ghash/state.rs new file mode 100644 index 000000000..7d3333b14 --- /dev/null +++ b/tls-2pc-core/src/ghash/state.rs @@ -0,0 +1,46 @@ +mod sealed { + pub trait Sealed {} + + impl Sealed for super::Init {} + impl Sealed for super::Intermediate {} + impl Sealed for super::Finalized {} +} + +pub trait State: sealed::Sealed {} + +impl State for Init {} +impl State for Intermediate {} +impl State for Finalized {} + +/// Init state for Ghash protocol +/// +/// This is before any OT has taken place +#[derive(Clone, Debug)] +pub struct Init { + pub(super) add_share: u128, +} + +/// Intermediate state for Ghash protocol +/// +/// This is when the additive share has been converted into a multiplicative share and all the +/// needed powers have been computed +#[derive(Clone, Debug)] +pub struct Intermediate { + pub(super) odd_mul_shares: Vec, + // A vec of all additive shares (even and odd) we already have. + // (In order to simplify the code) the n-th index of the vec corresponds to the additive share + // of the (n+1)-th power of H, e.g. the share of H^1 is located at the 0-th index of the vec + // It always contains an even number of consecutive shares starting from the share of H^1 up to + // the share of H^(cached_add_shares.len()) + pub(super) cached_add_shares: Vec, +} + +/// Final state for Ghash protocol +/// +/// This is when each party can compute a final share of the ghash output, because both now have +/// additive shares of all the powers of `H` +#[derive(Clone, Debug)] +pub struct Finalized { + pub(super) odd_mul_shares: Vec, + pub(super) add_shares: Vec, +} diff --git a/tls-2pc-core/src/ghash/utils.rs b/tls-2pc-core/src/ghash/utils.rs deleted file mode 100644 index 0be819e26..000000000 --- a/tls-2pc-core/src/ghash/utils.rs +++ /dev/null @@ -1,586 +0,0 @@ -use crate::ghash::{MXTableFull, YBits}; -use rand::{CryptoRng, Rng}; -use std::collections::BTreeMap; -use utils::iter::u8vec_to_boolvec; - -/// R is GCM polynomial in little-endian. In hex: "E1000000000000000000000000000000" -const R: u128 = 299076299051606071403356588563077529600; - -/// Galois field multiplication of two 128-bit blocks reduced by the GCM polynomial -pub fn block_mult(mut x: u128, y: u128) -> u128 { - let mut result: u128 = 0; - for i in (0..128).rev() { - result ^= x * ((y >> i) & 1); - x = (x >> 1) ^ ((x & 1) * R); - } - result -} - -/// Returns the squared value. It is called "free" because due to the -/// property of Galois multiplication, squaring can be done locally without -/// the need for 2PC. -pub fn free_square(x: u128) -> u128 { - block_mult(x, x) -} - -/// Performs squaring of each odd power in "powers" up to and including -/// the maximum power "max" and returns an updated map of powers. Squaring -/// will be done recursively if needed, e.g if we have power == 1 and "max" is 22, -/// then 1 will be squared to get power == 2, then 2 -> 4, 4 -> 8, 8 -> 16. -/// Those powers which have already been squared will be skipped. -pub fn square_all(powers: &BTreeMap, max: u16) -> BTreeMap { - let mut new_powers: BTreeMap = BTreeMap::new(); - for (power, value) in powers.iter() { - // The fact the we had earlier computed more powers that we will ever - // need is a sign of a logic error which needs to be investigated. - assert!(*power <= max); - new_powers.insert(*power, *value); - if power % 2 == 0 { - continue; - } - // existing_power is the power for which we have the value. - let mut existing_power = *power; - while existing_power * 2 <= max { - // check if the squaring has already been done, otherwise do it now. - let option = powers.get(&(existing_power * 2)); - let squared_value: u128; - if option == None { - let value_to_square = new_powers.get(&existing_power).unwrap(); - squared_value = free_square(*value_to_square); - } else { - squared_value = *option.unwrap(); - } - new_powers.insert(existing_power * 2, squared_value); - existing_power *= 2; - } - } - new_powers -} - -/// Finds 2 non-equal summands which add up to the needed sum. The -/// first returned summand will be as small as possible. -/// E.g if "summands" keys are 1,2,3,5,6 and "sum_needed" is 8, then -/// the returned value would be (2,6). -pub fn find_sum(summands: &BTreeMap, sum_needed: u16) -> (u16, u16) { - for (i, _) in summands.iter() { - for (j, _) in summands.iter() { - if *j == *i { - continue; - } - if *i + *j == sum_needed { - return (*i, *j); - } - } - } - // Should never get here. We only call find_sum when we know in advance - // that summands will be found. - panic!("summands were not found") -} - -/// Returns the maximum odd power that we'll need to compute GHASH in 2PC -/// using Block Aggregation, where "max" is maximum power for GHASH -/// (i.e it is the amount of GHASH blocks). -pub fn find_max_odd_power(max: u16) -> u8 { - assert!(max <= 1026); - // max_htable's shows how many GHASH blocks can be processed - // with Block Aggregation if we have all the sequential shares - // starting with 1 up to and including . - // e.g. (5, 29) means that if we have shares of H^1,H^2,H^3,H^4,H^5, - // then we can process 29 GHASH blocks. - // max TLS record size of 16KB requires 1026 GHASH blocks - let max_htable: BTreeMap = BTreeMap::from([ - (0, 0), - (3, 19), - (5, 29), - (7, 71), - (9, 89), - (11, 107), - (13, 125), - (15, 271), - (17, 305), - (19, 339), - (21, 373), - (23, 407), - (25, 441), - (27, 475), - (29, 509), - (31, 1023), - (33, 1025), - (35, 1027), - ]); - let mut out = 0u8; - for (key, value) in max_htable.iter() { - if *value >= max { - out = *key; - break; - } - } - out -} - -/// Multiplies GHASH blocks by the corresponding shares of powers of H and -/// returns the sum of all products. If some share is not present, the -/// corresponding block is not multiplied at this stage but it will later -/// participate in block aggregation. -pub fn multiply_powers_and_blocks(powers: &BTreeMap, blocks: &Vec) -> u128 { - let last_key = *powers.iter().last().unwrap().0; - assert!(last_key as usize <= blocks.len()); - let mut sum = 0u128; - for (power, value) in powers.iter() { - // in GHASH, H^1 is multiplied with the last block, H^2 with the second to last - // block, etc. - sum ^= block_mult(*value, blocks[blocks.len() - (*power as usize)]); - } - sum -} - -/// Implements the block aggregation method. -pub fn block_aggregation( - powers: &BTreeMap, - blocks: &Vec, -) -> (BTreeMap, u128) { - let mut ghash_share = 0u128; - let mut aggregated: BTreeMap = BTreeMap::new(); - for i in 1..blocks.len() + 1 { - if powers.get(&(i as u16)) != None { - // we already multiplied the block with this share of power in - // multiply_powers_and_blocks() - continue; - } - // else we found a power of H which we don't have. - let (small, big) = find_sum(&powers, i as u16); - let block = blocks[blocks.len() - i]; - ghash_share ^= block_mult( - block_mult(*powers.get(&small).unwrap(), *powers.get(&big).unwrap()), - block, - ); - // initialize the value if it doesn't exist - if aggregated.get(&small) == None { - aggregated.insert(small, 0u128); - } - // update value - let old_value = *aggregated.get(&small).unwrap(); - aggregated.insert( - small, - old_value ^ block_mult(*powers.get(&big).unwrap(), block), - ); - } - (aggregated, ghash_share) -} - -/// Returns YBits which Master needs to complete Block Aggregation. -pub fn block_aggregation_bits( - powers: &BTreeMap, - aggregated: &BTreeMap, -) -> Vec { - let mut all_bits: Vec = Vec::new(); - for (power, value) in aggregated.iter() { - // Master sends first bits of power then bits of value. Slave sends - // masked x tables in reverse order. - all_bits.push(u8vec_to_boolvec( - &(*powers.get(power).unwrap()).to_be_bytes(), - )); - all_bits.push(u8vec_to_boolvec(&value.to_be_bytes())); - } - all_bits -} - -/// Returns masked X tables which Slave needs to complete Block Aggregation. -pub fn block_aggregation_mxtables( - rng: &mut R, - powers: &BTreeMap, - aggregated: &BTreeMap, -) -> (Vec, u128) { - let mut all_mxtables: Vec = Vec::new(); - let mut sum = 0u128; - for (power, value) in aggregated.iter() { - // Slave sends first masked x table of agregated value then masked x - // table of power value. - let (mxtable1, sum1) = masked_xtable(rng, *value); - let (mxtable2, sum2) = masked_xtable(rng, *powers.get(power).unwrap()); - sum ^= sum1 ^ sum2; - all_mxtables.push(mxtable1); - all_mxtables.push(mxtable2); - } - (all_mxtables, sum) -} - -/// Returns a table of values of x after each of the 128 rounds of blockMult() -fn xtable(mut x: u128) -> Vec { - let mut x_table: Vec = vec![0u128; 128]; - for i in 0..128 { - x_table[i] = x; - x = (x >> 1) ^ ((x & 1) * R); - } - x_table -} - -/// Returns: -/// 1) a masked xTable from which OT response will be constructed and -/// 2) the XOR-sum of all masks which is our share of the block multiplication product -/// For each value of xTable, the masked xTable will contain 2 values: -/// 1) a random mask and -/// 2) the xTable entry masked with the random mask. -pub fn masked_xtable(rng: &mut R, x: u128) -> (MXTableFull, u128) { - let x_table = xtable(x); - // maskSum is the xor sum of all masks - let mut mask_sum: u128 = 0; - let mut masked_xtable: MXTableFull = vec![[0u128; 2]; 128]; - for i in 0..128 { - let mask: u128 = rng.gen(); - mask_sum ^= mask; - masked_xtable[i][0] = mask; - masked_xtable[i][1] = x_table[i] ^ mask; - } - (masked_xtable, mask_sum) -} - -/// Returns the XOR sum of all elements of the vector. -pub fn xor_sum(vec: &Vec) -> u128 { - vec.iter().fold(0u128, |acc, x| acc ^ x) -} - -/// Converts a flat vector into a vector of chunks of the needed size. -pub fn flat_to_chunks(flat: &Vec, chunk_size: usize) -> Vec> -where - T: Clone, -{ - let count = flat.len() / chunk_size; - let mut vec_chunks: Vec> = Vec::with_capacity(count); - for chunk in flat.chunks(chunk_size) { - vec_chunks.push(chunk.to_vec()); - } - vec_chunks -} - -#[cfg(test)] -mod tests { - use super::*; - use ghash_rc::{ - universal_hash::{NewUniversalHash, UniversalHash}, - GHash, - }; - use rand::{thread_rng, Rng, SeedableRng}; - use rand_chacha::ChaCha12Rng; - use std::convert::TryInto; - - #[test] - fn test_block_mult() { - let mut rng = thread_rng(); - let x: u128 = rng.gen(); - let y: u128 = rng.gen(); - assert_eq!(block_mult(x, y), rust_crypto_ghash(x, &vec![y])); - } - - #[test] - fn test_free_square() { - let mut rng = thread_rng(); - let x: u128 = rng.gen(); - assert_eq!(free_square(x), rust_crypto_ghash(x, &vec![x])); - } - - #[test] - fn test_square_all() { - let mut powers_keys: Vec = vec![1, 3, 5, 6]; - let mut max_power = 8; - let mut powers = setup_square_all(powers_keys); - let mut res = square_all(&powers, max_power); - let mut res_keys: Vec = res.keys().cloned().collect(); - let mut res_values: Vec = res.values().cloned().collect(); - assert_eq!(res_keys, vec![1, 2, 3, 4, 5, 6, 8]); - assert_eq!( - res_values, - vec![ - 12346, - 37384537381450758925419573570612497787, - 12348, - 330354857586696702251049094163216396600, - 12350, - 12351, - 226635212255694396134298606764692048245 - ] - ); - - powers_keys = vec![1, 101]; - max_power = 255; - powers = setup_square_all(powers_keys); - res = square_all(&powers, max_power); - res_keys = res.keys().cloned().collect(); - res_values = res.values().cloned().collect(); - assert_eq!(res_keys, vec![1, 2, 4, 8, 16, 32, 64, 101, 128, 202]); - assert_eq!( - res_values, - vec![ - 12346, - 37384537381450758925419573570612497787, - 330354857586696702251049094163216396600, - 226635212255694396134298606764692048245, - 57435732663173249101181542372863021389, - 221430791223604693286277902449016542898, - 102715545892785352441374614451375486130, - 12446, - 308951395680463668816102109032313818778, - 144387391042136486694176041923180290643 - ] - ); - } - - #[test] - #[should_panic] - fn test_square_all_should_panic() { - let powers_keys: Vec = vec![1, 3, 5, 6, 9]; - let max_power = 8; - let powers = setup_square_all(powers_keys); - // will panick because we have power 9 when max needed is 8 - square_all(&powers, max_power); - } - - #[test] - fn test_find_sum() { - let summands = setup_find_sum(); - assert_eq!(find_sum(&summands, 8), (3, 5)); - assert_eq!(find_sum(&summands, 14), (5, 9)); - assert_eq!(find_sum(&summands, 21), (6, 15)); - } - - #[test] - #[should_panic] - fn test_find_sum_should_panic() { - let summands = setup_find_sum(); - // will panick because summands must be non-equal, so (15, 15) is not - // allowed - find_sum(&summands, 30); - // will panick because no two summands add up to 25 - find_sum(&summands, 25); - } - - #[test] - fn test_find_max_odd_power() { - assert_eq!(find_max_odd_power(1), 3); - assert_eq!(find_max_odd_power(20), 5); - assert_eq!(find_max_odd_power(100), 11); - assert_eq!(find_max_odd_power(1000), 31); - } - - #[test] - fn test_multiply_powers_and_blocks() { - let mut rng = thread_rng(); - let block_count = 10; - let h: u128 = rng.gen(); - let powers_of_h = compute_expected_powers(h, 10); - let mut powers: BTreeMap = BTreeMap::new(); - let mut blocks: Vec = Vec::new(); - // init all powers and all blocks - for i in 0..block_count { - powers.insert(i + 1, powers_of_h[(i + 1) as usize]); - blocks.push(rng.gen()); - } - let result = multiply_powers_and_blocks(&powers, &blocks); - assert_eq!(result, rust_crypto_ghash(h, &blocks)); - } - - #[test] - #[should_panic] - fn test_multiply_powers_and_blocks_should_panic() { - let mut rng = thread_rng(); - let block_count = 10; - let h: u128 = rng.gen(); - let powers_of_h = compute_expected_powers(h, block_count); - let mut powers: BTreeMap = BTreeMap::new(); - let mut blocks: Vec = Vec::new(); - // init all powers and all blocks - for i in 0..block_count { - powers.insert(i + 1, powers_of_h[(i + 1) as usize]); - blocks.push(rng.gen()); - } - // insert an extra power to have more powers than blocks which is a - // sign of a logic error. - powers.insert( - block_count + 1, - block_mult(powers_of_h[block_count as usize], h), - ); - multiply_powers_and_blocks(&powers, &blocks); - } - - #[test] - fn test_block_aggregation() { - let h: u128 = 123456; - let block_count = 10; - let powers_of_h = compute_expected_powers(h, block_count); - let mut powers_map: BTreeMap = BTreeMap::new(); - let mut blocks: Vec = Vec::with_capacity(block_count as usize); - for i in 1..block_count + 1 { - powers_map.insert(i, powers_of_h[i as usize]); - blocks.push(1234567 + i as u128); - } - // all powers are in place, so no block aggregation will happen - assert_eq!( - block_aggregation(&powers_map, &blocks), - (BTreeMap::new(), 0) - ); - // remove some powers - powers_map.remove(&5); - powers_map.remove(&7); - let mut expected_map: BTreeMap = BTreeMap::new(); - expected_map.insert(1, 6529972824624832318862907648013721286); - assert_eq!( - block_aggregation(&powers_map, &blocks), - (expected_map, 315833047958356732231847338615588728787) - ); - } - - #[test] - fn test_block_aggregation_bits() { - let mut powers_map: BTreeMap = BTreeMap::new(); - let mut aggregated_map: BTreeMap = BTreeMap::new(); - powers_map.insert(1, 256); - aggregated_map.insert(1, 512); - let mut expected: [[bool; 128]; 2] = [[false; 128]; 2]; - expected[0][128 - 9] = true; // set bit for 256 - expected[1][128 - 10] = true; // set bit for 512 - assert_eq!( - expected.concat(), - block_aggregation_bits(&powers_map, &aggregated_map).concat() - ); - } - - #[test] - fn test_block_aggregation_mxtables() { - let mut rng = ChaCha12Rng::seed_from_u64(12345); - - let mut powers_map: BTreeMap = BTreeMap::new(); - let mut aggregated_map: BTreeMap = BTreeMap::new(); - powers_map.insert(1, 256); - aggregated_map.insert(1, 512); - let mut expected: [[bool; 128]; 2] = [[false; 128]; 2]; - expected[0][128 - 9] = true; // set bit for 256 - expected[1][128 - 10] = true; // set bit for 512 - let (mxtables, share) = block_aggregation_mxtables(&mut rng, &powers_map, &aggregated_map); - assert_eq!(share, 119435139769675579125133100514879089925); - // since mxtables output is huge, we check only a few arbitrary elements of it - assert_eq!( - mxtables.concat()[23], - [ - 87470858790173581075227934021272140429, - 87445059565157736150448673117636328077 - ] - ); - assert_eq!( - mxtables.concat()[54], - [ - 332311817981764475918232627385210634271, - 332311817981747626514621748491088166943 - ] - ); - assert_eq!( - mxtables.concat()[10], - [ - 111226209635646018502889366408372405024, - 237502869235213026428751037135005139744 - ] - ); - } - - #[test] - fn test_xtable() { - let result = xtable(123456u128); - // since xtable output is huge, we check only a few arbitrary elements of it - assert_eq!(result[0], 123456); - assert_eq!(result[39], 57076772936301564299567746252800); - assert_eq!(result[111], 12086476800); - } - - #[test] - fn test_masked_xtable() { - let mut rng = thread_rng(); - let x: u128 = rng.gen(); - let y: u128 = rng.gen(); - let expected = block_mult(x, y); - assert_eq!(expected, product_from_shares(x, y)); - - // corrupt some bytes of y value - let mut bad_bytes = y.to_be_bytes(); - bad_bytes[5] = bad_bytes[5].checked_add(1).unwrap_or_default(); - bad_bytes[10] = bad_bytes[10].checked_add(1).unwrap_or_default(); - bad_bytes[15] = bad_bytes[15].checked_add(1).unwrap_or_default(); - let bad = u128::from_be_bytes(bad_bytes); - assert_ne!(expected, product_from_shares(x, bad)); - } - - #[test] - fn test_xor_sum() { - let mut rng = thread_rng(); - let mut summands: Vec = Vec::new(); - for _i in 0..300 { - let rand = rng.gen(); - summands.push(rand); - summands.push(rand); - } - // xoring the same value twice should result in zero - assert_eq!(xor_sum(&summands), 0); - summands.push(123456); - assert_eq!(xor_sum(&summands), 123456); - } - - // compute GHASH using RustCrypto's ghash - fn rust_crypto_ghash(h: u128, blocks: &Vec) -> u128 { - let mut ghash = GHash::new(&h.to_be_bytes().into()); - for block in blocks.iter() { - ghash.update(&block.to_be_bytes().into()); - } - let b = ghash.finalize().into_bytes(); - u128::from_be_bytes(b.as_slice().try_into().unwrap()) - } - - // prepare the expected powers of h by recursively multiplying h to - // itself - fn compute_expected_powers(h: u128, max: u16) -> Vec { - // prepare the expected powers of h by recursively multiplying h to - // itself - let mut powers: Vec = vec![0u128; (max + 1) as usize]; - powers[1] = h; - let mut prev_power = h; - for i in 2..((max + 1) as usize) { - powers[i] = block_mult(prev_power, h); - prev_power = powers[i]; - } - powers - } - - fn setup_find_sum() -> BTreeMap { - let summands_keys: Vec = vec![1, 3, 5, 6, 8, 9, 12, 15]; - let mut summands: BTreeMap = BTreeMap::new(); - // assign any value > 0 to elements at keys corresponding to - // summands_keys - for v in summands_keys.iter() { - summands.insert(*v, (12345 + *v) as u128); - } - summands - } - - fn setup_square_all(keys: Vec) -> BTreeMap { - let mut powers: BTreeMap = BTreeMap::new(); - // assign any value to elements at keys corresponding to - // powers_keys - for v in keys.iter() { - powers.insert(*v, (12345 + *v) as u128); - } - powers - } - - fn product_from_shares(x: u128, y: u128) -> u128 { - // instantiate with empty values, we only need rng for this test - let (masked_xtable, my_product_share) = masked_xtable(&mut thread_rng(), x); - - // the other party who has the y value will receive only 1 value (out of 2) - // for each entry in maskedXTable via Oblivious Transfer depending on the - // bits of y. We simulate that here: - let mut his_product_share = 0u128; - let bits = u8vec_to_boolvec(&y.to_be_bytes()); - for i in 0..128 { - // the first element in xTable corresponds to the highest bit of y - his_product_share ^= masked_xtable[i][bits[i] as usize]; - } - my_product_share ^ his_product_share - } -}