mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-01-09 21:38:00 -05:00
implement GHASH 2PC (#20)
* implement GHASH 2PC * receiver_tls test * make GHASH more generic * wrap temp_share in an Option<> * fix warnings * make modules public * add traits * add feature Co-authored-by: themighty1 <you@example.com>
This commit is contained in:
@@ -8,14 +8,19 @@ name = "tls_core"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
[features]
|
||||
default = ["prf"]
|
||||
default = ["prf", "ghash"]
|
||||
prf = []
|
||||
ghash = []
|
||||
|
||||
[dependencies]
|
||||
tlsn-mpc-core = { path = "../mpc-core" }
|
||||
sha2 = { version = "0.10.1", features = ["compress"] }
|
||||
digest = { version = "0.10.3" }
|
||||
hmac = { version = "0.12.1" }
|
||||
rand = "0.8.5"
|
||||
thiserror = "1.0.30"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.3.5"
|
||||
ghash_rc = { package = "ghash", version = "0.4.4" }
|
||||
rand_chacha = "0.3.1"
|
||||
|
||||
151
tls-core/src/ghash/common.rs
Normal file
151
tls-core/src/ghash/common.rs
Normal file
@@ -0,0 +1,151 @@
|
||||
use super::errors::*;
|
||||
use super::utils::{find_max_odd_power, find_sum, square_all};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
// GhashCommon is common to both Master and Slave
|
||||
pub struct GhashCommon {
|
||||
// blocks are input blocks for GHASH. In TLS the first block is AAD,
|
||||
// the middle blocks are AES blocks - the ciphertext, the last block
|
||||
// is len(AAD)+len(ciphertext)
|
||||
pub blocks: Vec<u128>,
|
||||
// powers are our XOR shares of the powers of H (H is the GHASH key).
|
||||
// We need as many powers as there blocks. Value at key==1 corresponds to the share
|
||||
// of H^1, value at key==2 to the share of H^2 etc.
|
||||
pub powers: BTreeMap<u16, u128>,
|
||||
// max_odd_power is the maximum odd power that we'll need to compute
|
||||
// GHASH in 2PC using Block Aggregation
|
||||
pub max_odd_power: u8,
|
||||
// strategies are initialized in ::new(). See comments there.
|
||||
pub strategies: [BTreeMap<u8, [u8; 2]>; 2],
|
||||
// temp_share is used to save an intermediate GHASH share
|
||||
pub temp_share: Option<u128>,
|
||||
}
|
||||
|
||||
impl GhashCommon {
|
||||
pub fn new(ghash_key_share: u128, blocks: Vec<u128>) -> Result<Self, GhashError> {
|
||||
if blocks.len() < 3 || blocks.len() > 1026 {
|
||||
return Err(GhashError::BlockCountWrong);
|
||||
}
|
||||
let mut powers = BTreeMap::new();
|
||||
// GHASH key is our share H^1
|
||||
powers.insert(1, ghash_key_share);
|
||||
let max_odd_power = find_max_odd_power(blocks.len() as u16);
|
||||
powers = square_all(&powers, blocks.len() as u16);
|
||||
|
||||
// strategy1 and startegy2 are only relevant for the Block Aggregation method.
|
||||
// They show what existing shares of the powers of H (H is the GHASH key) we
|
||||
// will be multiplying (value[0] and value[1]) to obtain other odd shares (<key>).
|
||||
// Max sequential odd share that we can obtain on first round of
|
||||
// communication is 19. We already have 1) shares of H^1, H^2, H^3 from
|
||||
// the Client Finished message and 2) squares of those 3 shares.
|
||||
// Note that "sequential" is a keyword here. We can't obtain 21 but we
|
||||
// indeed can obtain 25==24+1, 33==32+1 etc. However with 21 missing,
|
||||
// even if we have 25,33,etc, there will be a gap and we will not be able
|
||||
// to obtain all the needed shares by Block Aggregation.
|
||||
|
||||
// We request OT for each share in each pair of the strategy, i.e. for
|
||||
// shares: 4,1,4,3,8,1, etc. Even though it would be possible to introduce
|
||||
// optimizations in order to avoid requesting OT for the same share more
|
||||
// than once, that would only save us ~2000 OT instances at the cost of
|
||||
// complicating the code.
|
||||
|
||||
let strategy1: BTreeMap<u8, [u8; 2]> = BTreeMap::from([
|
||||
(5, [4, 1]),
|
||||
(7, [4, 3]),
|
||||
(9, [8, 1]),
|
||||
(11, [8, 3]),
|
||||
(13, [12, 1]),
|
||||
(15, [12, 3]),
|
||||
(17, [16, 1]),
|
||||
(19, [16, 3]),
|
||||
]);
|
||||
let strategy2: BTreeMap<u8, [u8; 2]> = BTreeMap::from([
|
||||
(21, [17, 4]),
|
||||
(23, [17, 6]),
|
||||
(25, [17, 8]),
|
||||
(27, [19, 8]),
|
||||
(29, [17, 12]),
|
||||
(31, [19, 12]),
|
||||
(33, [17, 16]),
|
||||
(35, [19, 16]),
|
||||
]);
|
||||
|
||||
Ok(Self {
|
||||
max_odd_power,
|
||||
blocks,
|
||||
powers,
|
||||
strategies: [strategy1, strategy2],
|
||||
temp_share: Some(0u128),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the amount of Oblivious Transfer instances needed to complete
|
||||
/// the protocol. The purpose is to inform the OT layer.
|
||||
pub fn calculate_ot_count(&mut self) -> usize {
|
||||
let mut powers: BTreeMap<u16, u128> = BTreeMap::new();
|
||||
// only 2 2PC multiplications in round 1
|
||||
let r1count = 2;
|
||||
powers.insert(1, 0u128);
|
||||
powers.insert(2, 0u128);
|
||||
powers.insert(3, 0u128);
|
||||
// since we just added a few new shares of powers, we need to square them
|
||||
self.powers = square_all(&powers, self.blocks.len() as u16);
|
||||
|
||||
// merge 2 strategies maps into 1
|
||||
let strategy: BTreeMap<u8, [u8; 2]> = self.strategies[0]
|
||||
.clone()
|
||||
.into_iter()
|
||||
.chain(self.strategies[1].clone())
|
||||
.collect();
|
||||
// number of multiplications in rounds 2 and 3
|
||||
let mut r2and3count = 0;
|
||||
for (key, _) in strategy.iter() {
|
||||
if *key > self.max_odd_power {
|
||||
break;
|
||||
};
|
||||
powers.insert(*key as u16, 0u128);
|
||||
r2and3count += 2;
|
||||
}
|
||||
// since we just added a few new shares of powers, we need to square them
|
||||
powers = square_all(&powers, self.blocks.len() as u16);
|
||||
|
||||
let mut aggregated: BTreeMap<u16, u128> = BTreeMap::new();
|
||||
for i in 1..self.blocks.len() + 1 {
|
||||
if powers.get(&(i as u16)) != None {
|
||||
continue;
|
||||
}
|
||||
let (small, _) = find_sum(&powers, i as u16);
|
||||
// initialize the value if it doesn't exist
|
||||
if aggregated.get(&small) == None {
|
||||
aggregated.insert(small, 0u128);
|
||||
}
|
||||
}
|
||||
let r4count = aggregated.len() * 2;
|
||||
// each multiplication requires 128 instances of Oblivious Transfer
|
||||
(r1count + r2and3count + r4count) * 128
|
||||
}
|
||||
|
||||
pub fn export_powers(&mut self) -> BTreeMap<u16, u128> {
|
||||
self.powers.clone()
|
||||
}
|
||||
|
||||
pub fn is_round2_needed(&self) -> bool {
|
||||
// after round 1 we will have consecutive powers 1,2,3 which is enough
|
||||
// to compute GHASH for 19 blocks with block aggregation.
|
||||
self.blocks.len() > 19
|
||||
}
|
||||
|
||||
pub fn is_round3_needed(&self) -> bool {
|
||||
// after round 2 we will have a max of up to 19 consequitive odd powers
|
||||
// which allows us to get 339 powers with block aggregation, see max_htable
|
||||
// in utils::find_max_odd_power()
|
||||
self.blocks.len() > 339
|
||||
}
|
||||
|
||||
pub fn is_only_1_round(&self) -> bool {
|
||||
// block agregation is always used except for very small block count
|
||||
// where powers from round 1 are sufficient to perform direct multiplication
|
||||
// of blocks by powers
|
||||
self.blocks.len() <= 4
|
||||
}
|
||||
}
|
||||
12
tls-core/src/ghash/errors.rs
Normal file
12
tls-core/src/ghash/errors.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
/// Errors that may occur when using ghash module
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum GhashError {
|
||||
#[error("Message was received out of order")]
|
||||
OutOfOrder,
|
||||
#[error("The other party sent data of wrong size")]
|
||||
DataLengthWrong,
|
||||
#[error("Tried to pass unsupported block count")]
|
||||
BlockCountWrong,
|
||||
#[error("Tried to finalize before the protocol was complete")]
|
||||
FinalizeCalledTooEarly,
|
||||
}
|
||||
245
tls-core/src/ghash/master.rs
Normal file
245
tls-core/src/ghash/master.rs
Normal file
@@ -0,0 +1,245 @@
|
||||
//! Implements the GHASH Master. This is the party which holds the Y value of
|
||||
//! block multiplication. Master acts as the receiver of the Oblivious
|
||||
//! Transfer and receives Slaves's masked X table entries obliviously for each
|
||||
//! bit of Y.
|
||||
use super::utils::{
|
||||
block_aggregation, block_aggregation_bits, block_mult, flat_to_chunks,
|
||||
multiply_powers_and_blocks, square_all, xor_sum,
|
||||
};
|
||||
use super::{errors::*, MasterCore};
|
||||
use crate::ghash::common::GhashCommon;
|
||||
use crate::ghash::{MXTable, YBits};
|
||||
use mpc_core::utils::u8vec_to_boolvec;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
#[derive(PartialEq)]
|
||||
pub enum MasterState {
|
||||
Initialized,
|
||||
// There may be 1, 2, 3 or 4 rounds depending on GHASH block count
|
||||
RoundSent(usize),
|
||||
RoundReceived(usize),
|
||||
Complete,
|
||||
}
|
||||
|
||||
pub struct GhashMaster {
|
||||
c: GhashCommon,
|
||||
state: MasterState,
|
||||
// is_last_round will be set to true by next_request() to indicate that
|
||||
// after the response is received the state must be set to Complete
|
||||
is_last_round: bool,
|
||||
}
|
||||
|
||||
impl MasterCore for GhashMaster {
|
||||
fn next_request(&mut self) -> Result<Vec<bool>, GhashError> {
|
||||
let retval;
|
||||
let is_complete;
|
||||
match self.state {
|
||||
MasterState::Initialized => {
|
||||
self.state = MasterState::RoundSent(1);
|
||||
retval = self.get_ybits_for_round1().concat();
|
||||
is_complete = self.is_only_1_round();
|
||||
}
|
||||
MasterState::RoundReceived(1) => {
|
||||
if self.is_round2_needed() {
|
||||
self.state = MasterState::RoundSent(2);
|
||||
retval = self.get_ybits_for_round(2).concat();
|
||||
is_complete = false;
|
||||
} else {
|
||||
// rounds 2 and 3 will be skipped
|
||||
self.state = MasterState::RoundSent(4);
|
||||
retval = self.get_ybits_for_block_aggr().concat();
|
||||
is_complete = true;
|
||||
}
|
||||
}
|
||||
MasterState::RoundReceived(2) => {
|
||||
if self.is_round3_needed() {
|
||||
self.state = MasterState::RoundSent(3);
|
||||
retval = self.get_ybits_for_round(3).concat();
|
||||
is_complete = false;
|
||||
} else {
|
||||
// round 3 will be skipped
|
||||
self.state = MasterState::RoundSent(4);
|
||||
retval = self.get_ybits_for_block_aggr().concat();
|
||||
is_complete = true;
|
||||
}
|
||||
}
|
||||
MasterState::RoundReceived(3) => {
|
||||
self.state = MasterState::RoundSent(4);
|
||||
retval = self.get_ybits_for_block_aggr().concat();
|
||||
is_complete = true;
|
||||
}
|
||||
_ => {
|
||||
return Err(GhashError::OutOfOrder);
|
||||
}
|
||||
}
|
||||
if is_complete {
|
||||
self.is_last_round = true;
|
||||
}
|
||||
Ok(retval)
|
||||
}
|
||||
|
||||
fn process_response(&mut self, response: &Vec<u128>) -> Result<(), GhashError> {
|
||||
if response.len() % 128 != 0 {
|
||||
return Err(GhashError::DataLengthWrong);
|
||||
}
|
||||
let mxtables = flat_to_chunks(response, 128);
|
||||
match self.state {
|
||||
MasterState::RoundSent(1) => {
|
||||
self.state = MasterState::RoundReceived(1);
|
||||
self.process_mxtables_for_round1(&mxtables);
|
||||
}
|
||||
MasterState::RoundSent(2) => {
|
||||
self.state = MasterState::RoundReceived(2);
|
||||
self.process_mxtables_for_round(&mxtables, 2);
|
||||
}
|
||||
MasterState::RoundSent(3) => {
|
||||
self.state = MasterState::RoundReceived(3);
|
||||
self.process_mxtables_for_round(&mxtables, 3);
|
||||
}
|
||||
MasterState::RoundSent(4) => {
|
||||
self.state = MasterState::RoundReceived(4);
|
||||
self.process_mxtables_for_block_aggr(&mxtables);
|
||||
}
|
||||
_ => {
|
||||
return Err(GhashError::OutOfOrder);
|
||||
}
|
||||
}
|
||||
if self.is_last_round {
|
||||
if self.state != MasterState::RoundReceived(4) {
|
||||
// if the last round was not round 4 (i.e. there was no block
|
||||
// aggregation), then we compute GHASH directly
|
||||
self.c.temp_share =
|
||||
Some(multiply_powers_and_blocks(&self.c.powers, &self.c.blocks));
|
||||
}
|
||||
self.state = MasterState::Complete;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_complete(&mut self) -> bool {
|
||||
self.state == MasterState::Complete
|
||||
}
|
||||
|
||||
fn finalize(&mut self) -> Result<u128, GhashError> {
|
||||
if self.state != MasterState::Complete {
|
||||
return Err(GhashError::FinalizeCalledTooEarly);
|
||||
}
|
||||
Ok(self.c.temp_share.unwrap())
|
||||
}
|
||||
|
||||
/// Returns the amount of Oblivious Transfer instances needed to complete
|
||||
/// the protocol. The purpose is to inform the OT layer.
|
||||
fn calculate_ot_count(&mut self) -> usize {
|
||||
self.c.calculate_ot_count()
|
||||
}
|
||||
|
||||
fn export_powers(&mut self) -> BTreeMap<u16, u128> {
|
||||
self.c.export_powers()
|
||||
}
|
||||
}
|
||||
|
||||
impl GhashMaster {
|
||||
pub fn new(ghash_key_share: u128, blocks: Vec<u128>) -> Result<Self, GhashError> {
|
||||
let common = GhashCommon::new(ghash_key_share, blocks)?;
|
||||
Ok(Self {
|
||||
c: common,
|
||||
state: MasterState::Initialized,
|
||||
is_last_round: false,
|
||||
})
|
||||
}
|
||||
|
||||
fn is_round2_needed(&self) -> bool {
|
||||
self.c.is_round2_needed()
|
||||
}
|
||||
|
||||
fn is_round3_needed(&self) -> bool {
|
||||
self.c.is_round3_needed()
|
||||
}
|
||||
|
||||
fn is_only_1_round(&self) -> bool {
|
||||
self.c.is_only_1_round()
|
||||
}
|
||||
|
||||
/// Returns Y bits to compute H^3.
|
||||
fn get_ybits_for_round1(&mut self) -> Vec<YBits> {
|
||||
vec![
|
||||
u8vec_to_boolvec(&self.c.powers[&1].to_be_bytes()),
|
||||
u8vec_to_boolvec(&self.c.powers[&2].to_be_bytes()),
|
||||
]
|
||||
}
|
||||
|
||||
/// Takes masked X tables and computes our share of H^3.
|
||||
fn process_mxtables_for_round1(&mut self, mxtables: &Vec<MXTable>) {
|
||||
// the XOR sum of all masked xtables' values plus H^1*H^2 is our share of H^3
|
||||
self.c.powers.insert(
|
||||
3,
|
||||
xor_sum(&mxtables[0])
|
||||
^ xor_sum(&mxtables[1])
|
||||
^ block_mult(self.c.powers[&1], self.c.powers[&2]),
|
||||
);
|
||||
// since we just added a new share of powers, we need to square them
|
||||
self.c.powers = square_all(&self.c.powers, self.c.blocks.len() as u16);
|
||||
}
|
||||
|
||||
// Returns Y bits for a given round of communication.
|
||||
fn get_ybits_for_round(&mut self, round_no: u8) -> Vec<YBits> {
|
||||
assert!(round_no == 2 || round_no == 3);
|
||||
let mut bits: Vec<YBits> = Vec::new();
|
||||
for (key, value) in self.c.strategies[(round_no - 2) as usize].iter() {
|
||||
if *key > self.c.max_odd_power {
|
||||
break;
|
||||
}
|
||||
bits.push(u8vec_to_boolvec(
|
||||
&self.c.powers[&(value[0] as u16)].to_be_bytes(),
|
||||
));
|
||||
bits.push(u8vec_to_boolvec(
|
||||
&self.c.powers[&(value[1] as u16)].to_be_bytes(),
|
||||
));
|
||||
}
|
||||
bits
|
||||
}
|
||||
|
||||
/// Processes masked X tables for a given round of communication.
|
||||
fn process_mxtables_for_round(&mut self, mxtables: &Vec<MXTable>, round_no: u8) {
|
||||
assert!(round_no == 2 || round_no == 3);
|
||||
for (count, (power, factors)) in self.c.strategies[(round_no - 2) as usize]
|
||||
.iter()
|
||||
.enumerate()
|
||||
{
|
||||
if *power > self.c.max_odd_power {
|
||||
// for every key in the strategy which we processed, there
|
||||
// must have been 2 masked xtables
|
||||
assert!(count * 2 == mxtables.len());
|
||||
break;
|
||||
}
|
||||
// the XOR sum of 2 masked xtables' values plus the locally computed
|
||||
// term is our share of power
|
||||
let sum = xor_sum(&mxtables[count * 2]) ^ xor_sum(&mxtables[(count * 2) + 1]);
|
||||
let local_term = block_mult(
|
||||
self.c.powers[&(factors[0] as u16)],
|
||||
self.c.powers[&(factors[1] as u16)],
|
||||
);
|
||||
self.c.powers.insert(*power as u16, sum ^ local_term);
|
||||
}
|
||||
// since we just added a few new shares of powers, we need to square them
|
||||
self.c.powers = square_all(&self.c.powers, self.c.blocks.len() as u16);
|
||||
}
|
||||
|
||||
/// Returns Y bits for the block aggregation method.
|
||||
fn get_ybits_for_block_aggr(&mut self) -> Vec<YBits> {
|
||||
let share1 = multiply_powers_and_blocks(&self.c.powers, &self.c.blocks);
|
||||
let (aggregated, share2) = block_aggregation(&self.c.powers, &self.c.blocks);
|
||||
let choice_bits = block_aggregation_bits(&self.c.powers, &aggregated);
|
||||
self.c.temp_share = Some(share1 ^ share2);
|
||||
choice_bits
|
||||
}
|
||||
|
||||
/// Processes masked X tables for the block aggregation method.
|
||||
fn process_mxtables_for_block_aggr(&mut self, mxtables: &Vec<MXTable>) {
|
||||
let mut share = 0u128;
|
||||
for table in mxtables.iter() {
|
||||
share ^= xor_sum(table);
|
||||
}
|
||||
self.c.temp_share = Some(self.c.temp_share.unwrap() ^ share);
|
||||
}
|
||||
}
|
||||
295
tls-core/src/ghash/mod.rs
Normal file
295
tls-core/src/ghash/mod.rs
Normal file
@@ -0,0 +1,295 @@
|
||||
//! ghash implements a protocol of computing the AES-GCM's GHASH function in a
|
||||
//! secure two-party computation (2PC) setting using 1-out-of-2 Oblivious
|
||||
//! Transfer (OT). The parties start with their secret XOR shares of H (the
|
||||
//! GHASH key) and at the end each gets their XOR share of the GHASH output.
|
||||
//! The method is decribed here:
|
||||
//! (https://tlsnotary.org/how_it_works#section4).
|
||||
|
||||
//! As an illustration, let's say that S has his shares H1_s and H2_s and R
|
||||
//! has her shares H1_r and H2_r. They need to compute shares of H3.
|
||||
//! H3 = (H1_s + H1_r)*(H2_s + H2_r) = H1_s*H2_s + H1_s*H2_r + H1_r*H2_s +
|
||||
//! H1_r*H2_r. Term 1 can be computed by S locally and term 4 can be
|
||||
//! computed by R locally. Only terms 2 and 3 will be computed using
|
||||
//! GHASH 2PC. R will obliviously request values for bits of H1_r and H2_r.
|
||||
//! The XOR sum of all values which S will send back plus H1_r*H2_r will
|
||||
//! become R's share of H3.
|
||||
//!
|
||||
//! When performing block multiplication in 2PC, Master holds the Y value and
|
||||
//! Slave holds the X value. The Slave then computes the X table and masks it.
|
||||
|
||||
mod common;
|
||||
pub mod errors;
|
||||
pub mod master;
|
||||
pub mod slave;
|
||||
mod utils;
|
||||
|
||||
use errors::*;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
/// MXTableFull is masked XTable which Slave has at the beginning of OT.
|
||||
/// MXTableFull must not be revealed to Master.
|
||||
type MXTableFull = Vec<[u128; 2]>;
|
||||
/// MXTable is a masked x table which Master will end up having after OT.
|
||||
type MXTable = Vec<u128>;
|
||||
/// YBits are Master's bits of Y in big-endian. Based on these bits
|
||||
/// Master will send MXTable via OT.
|
||||
/// The convention for the returned Y bits:
|
||||
/// A) powers are in an ascending order: first powers[1], then powers[2] etc.
|
||||
/// B) bits of each power are in big-endian.
|
||||
type YBits = Vec<bool>;
|
||||
|
||||
pub trait MasterCore {
|
||||
/// Returns choice bits for Oblivious Transfer.
|
||||
/// While is_complete() returns false, next_request() must be called
|
||||
/// followed by process_response().
|
||||
fn next_request(&mut self) -> Result<Vec<bool>, GhashError>;
|
||||
|
||||
/// process_response() will be invoked by the Oblivious Transfer impl. It
|
||||
/// receives masked X tables acc. to our choice bits in next_request().
|
||||
fn process_response(&mut self, response: &Vec<u128>) -> Result<(), GhashError>;
|
||||
|
||||
/// Returns true when the protocol is complete.
|
||||
fn is_complete(&mut self) -> bool;
|
||||
|
||||
/// Returns our GHASH share.
|
||||
fn finalize(&mut self) -> Result<u128, GhashError>;
|
||||
|
||||
/// Returns the amount of Oblivious Transfer instances needed to complete
|
||||
/// the protocol. The purpose is to inform the OT layer.
|
||||
fn calculate_ot_count(&mut self) -> usize;
|
||||
|
||||
/// Exports powers of the GHASH key obtained at the current stage of the
|
||||
/// protocol.
|
||||
fn export_powers(&mut self) -> BTreeMap<u16, u128>;
|
||||
}
|
||||
|
||||
pub trait SlaveCore {
|
||||
/// Returns the full masked X table which must NOT be passed to Master. It
|
||||
/// must be consumed by the Oblivious Transfer impl.
|
||||
fn process_request(&mut self) -> Result<Vec<[u128; 2]>, GhashError>;
|
||||
|
||||
/// Returns true when the protocol is complete.
|
||||
fn is_complete(&mut self) -> bool;
|
||||
|
||||
/// Returns our GHASH share.
|
||||
fn finalize(&mut self) -> Result<u128, GhashError>;
|
||||
|
||||
/// Returns the amount of Oblivious Transfer instances needed to complete
|
||||
/// the protocol. The purpose is to inform the OT layer.
|
||||
fn calculate_ot_count(&mut self) -> usize;
|
||||
|
||||
/// Exports powers of the GHASH key obtained at the current stage of the
|
||||
/// protocol.
|
||||
fn export_powers(&mut self) -> BTreeMap<u16, u128>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::errors::GhashError;
|
||||
use super::utils::block_mult;
|
||||
use super::{master::GhashMaster, slave::GhashSlave, MasterCore, SlaveCore};
|
||||
use ghash_rc::{
|
||||
universal_hash::{NewUniversalHash, UniversalHash},
|
||||
GHash,
|
||||
};
|
||||
use rand::prelude::ThreadRng;
|
||||
use rand::{thread_rng, Rng};
|
||||
use std::convert::TryInto;
|
||||
|
||||
#[test]
|
||||
// test only round 1
|
||||
fn test_round1() {
|
||||
let block_count = 3;
|
||||
let (h, mut slave, mut master, blocks) = ghash_setup(block_count);
|
||||
run_round(&mut slave, &mut master).unwrap();
|
||||
let ghash = finalize(&mut slave, &mut master);
|
||||
assert_eq!(ghash, rust_crypto_ghash(h, &blocks));
|
||||
assert!(master.is_complete());
|
||||
assert!(slave.is_complete());
|
||||
}
|
||||
|
||||
#[test]
|
||||
// test state after rounds 1,2 (but before block aggregation)
|
||||
fn test_round12_before_block_aggregation() {
|
||||
let block_count = 30;
|
||||
let (h, mut slave, mut master, _blocks) = ghash_setup(block_count);
|
||||
run_round(&mut slave, &mut master).unwrap();
|
||||
run_round(&mut slave, &mut master).unwrap();
|
||||
|
||||
let s_powers = slave.export_powers();
|
||||
let r_powers = master.export_powers();
|
||||
let all_s_keys: Vec<u16> = s_powers.keys().cloned().collect();
|
||||
let all_r_keys: Vec<u16> = r_powers.keys().cloned().collect();
|
||||
let expected_keys = vec![1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28];
|
||||
let exp_powers = compute_expected_powers(h, block_count as u16);
|
||||
|
||||
assert_eq!(all_s_keys, expected_keys);
|
||||
assert_eq!(all_r_keys, expected_keys);
|
||||
// compare shares of powers against expected powers
|
||||
for key in expected_keys.iter() {
|
||||
assert_eq!(
|
||||
exp_powers[*key as usize],
|
||||
*s_powers.get(key).unwrap() ^ *r_powers.get(key).unwrap()
|
||||
);
|
||||
}
|
||||
assert!(!master.is_complete());
|
||||
assert!(!slave.is_complete());
|
||||
}
|
||||
|
||||
#[test]
|
||||
// test rounds 1,2,4
|
||||
fn test_round124() {
|
||||
let block_count = 30;
|
||||
let (h, mut slave, mut master, blocks) = ghash_setup(block_count);
|
||||
run_round(&mut slave, &mut master).unwrap();
|
||||
run_round(&mut slave, &mut master).unwrap();
|
||||
run_round(&mut slave, &mut master).unwrap();
|
||||
let ghash = finalize(&mut slave, &mut master);
|
||||
assert_eq!(ghash, rust_crypto_ghash(h, &blocks));
|
||||
assert!(master.is_complete());
|
||||
assert!(slave.is_complete());
|
||||
}
|
||||
|
||||
#[test]
|
||||
// test rounds 1,2,3,4
|
||||
fn test_round1234() {
|
||||
let block_count = 340;
|
||||
let (h, mut slave, mut master, blocks) = ghash_setup(block_count);
|
||||
run_round(&mut slave, &mut master).unwrap();
|
||||
run_round(&mut slave, &mut master).unwrap();
|
||||
run_round(&mut slave, &mut master).unwrap();
|
||||
run_round(&mut slave, &mut master).unwrap();
|
||||
let ghash = finalize(&mut slave, &mut master);
|
||||
assert_eq!(ghash, rust_crypto_ghash(h, &blocks));
|
||||
assert!(master.is_complete());
|
||||
assert!(slave.is_complete());
|
||||
}
|
||||
|
||||
#[test]
|
||||
// test export_powers() after round 1
|
||||
fn test_export_powers() {
|
||||
let block_count = 340;
|
||||
let (h, mut slave, mut master, blocks) = ghash_setup(block_count);
|
||||
run_round(&mut slave, &mut master).unwrap();
|
||||
let powers_s = slave.export_powers();
|
||||
let powers_r = master.export_powers();
|
||||
// we only have 4 consecutive powers 1,2,3,4 after round 1
|
||||
// we compute ghash for only the first 4 blocks
|
||||
let mut ghash_share_s = 0u128;
|
||||
for i in 0..4 {
|
||||
ghash_share_s ^= block_mult(*powers_s.get(&(4 - i)).unwrap(), blocks[i as usize]);
|
||||
}
|
||||
let mut ghash_share_r = 0u128;
|
||||
for i in 0..4 {
|
||||
ghash_share_r ^= block_mult(*powers_r.get(&(4 - i)).unwrap(), blocks[i as usize]);
|
||||
}
|
||||
let ghash = ghash_share_s ^ ghash_share_r;
|
||||
assert_eq!(ghash, rust_crypto_ghash(h, &blocks[0..4].to_vec()));
|
||||
assert!(!master.is_complete());
|
||||
assert!(!slave.is_complete());
|
||||
}
|
||||
|
||||
#[test]
|
||||
// test OT count against hard-coded values and also double-check against
|
||||
// the actual amount of Y bits which the Master requested.
|
||||
fn test_calculate_ot_count() {
|
||||
let block_counts = vec![3, 19, 200, 1026];
|
||||
// expected amount of 2PC block multiplications
|
||||
let expected = vec![2, 8, 44, 96];
|
||||
|
||||
for i in 0..block_counts.len() {
|
||||
let block_count = block_counts[i];
|
||||
let (_, mut slave, mut master, _) = ghash_setup(block_count);
|
||||
let mut ybits_count = 0;
|
||||
while !master.is_complete() {
|
||||
let req = master.next_request().unwrap();
|
||||
ybits_count += req.len();
|
||||
let resp = slave.process_request().unwrap();
|
||||
let masked_xtable = simulate_ot(&req, &resp);
|
||||
master.process_response(&masked_xtable).unwrap();
|
||||
}
|
||||
let ot_count = master.calculate_ot_count();
|
||||
assert!(ot_count == expected[i] * 128);
|
||||
assert!(ot_count == ybits_count);
|
||||
}
|
||||
}
|
||||
|
||||
fn finalize(sender: &mut GhashSlave<ThreadRng>, receiver: &mut GhashMaster) -> u128 {
|
||||
let sender_ghash_share = sender.finalize().unwrap();
|
||||
let receiver_ghash_share = receiver.finalize().unwrap();
|
||||
receiver_ghash_share ^ sender_ghash_share
|
||||
}
|
||||
|
||||
fn ghash_setup(block_count: usize) -> (u128, GhashSlave<ThreadRng>, GhashMaster, Vec<u128>) {
|
||||
let mut rng = thread_rng();
|
||||
// h is ghash key
|
||||
let h: u128 = rng.gen();
|
||||
// h_s is sender's XOR share of h
|
||||
let h_s: u128 = rng.gen();
|
||||
// h_r is receiver's XOR share of h
|
||||
let h_r: u128 = h ^ h_s;
|
||||
|
||||
let blocks: Vec<u128> = random_blocks(block_count);
|
||||
let sender = GhashSlave::new(rng, h_s, blocks.clone()).unwrap();
|
||||
let receiver = GhashMaster::new(h_r, blocks.clone()).unwrap();
|
||||
(h, sender, receiver, blocks)
|
||||
}
|
||||
|
||||
fn random_blocks(block_count: usize) -> Vec<u128> {
|
||||
let mut rng = thread_rng();
|
||||
let mut blocks: Vec<u128> = Vec::new();
|
||||
for _i in 0..block_count {
|
||||
blocks.push(rng.gen());
|
||||
}
|
||||
blocks
|
||||
}
|
||||
|
||||
// compute GHASH using RustCrypto's ghash
|
||||
fn rust_crypto_ghash(h: u128, blocks: &Vec<u128>) -> u128 {
|
||||
let mut ghash = GHash::new(&h.to_be_bytes().into());
|
||||
for block in blocks.iter() {
|
||||
ghash.update(&block.to_be_bytes().into());
|
||||
}
|
||||
let b = ghash.finalize().into_bytes();
|
||||
u128::from_be_bytes(b.as_slice().try_into().unwrap())
|
||||
}
|
||||
|
||||
// prepare the expected powers of h by recursively multiplying h to
|
||||
// itself
|
||||
fn compute_expected_powers(h: u128, max: u16) -> Vec<u128> {
|
||||
// prepare the expected powers of h by recursively multiplying h to
|
||||
// itself
|
||||
let mut powers: Vec<u128> = vec![0u128; (max + 1) as usize];
|
||||
powers[1] = h;
|
||||
let mut prev_power = h;
|
||||
for i in 2..((max + 1) as usize) {
|
||||
powers[i] = block_mult(prev_power, h);
|
||||
prev_power = powers[i];
|
||||
}
|
||||
powers
|
||||
}
|
||||
|
||||
// run_round runs the next round
|
||||
fn run_round(
|
||||
sender: &mut GhashSlave<ThreadRng>,
|
||||
receiver: &mut GhashMaster,
|
||||
) -> Result<(), GhashError> {
|
||||
let receiver_bits = receiver.next_request()?;
|
||||
let masked_xtable_full = sender.process_request()?;
|
||||
let masked_xtable = simulate_ot(&receiver_bits, &masked_xtable_full);
|
||||
receiver.process_response(&masked_xtable)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// normally Master will send his bits via OT to get only 1 out of 2 values
|
||||
// for each row of masked xtable. Here we simulate this OT behaviour.
|
||||
fn simulate_ot(receiver_bits: &Vec<bool>, mxtables_full: &Vec<[u128; 2]>) -> Vec<u128> {
|
||||
assert!(receiver_bits.len() == mxtables_full.len());
|
||||
let mut mxtables: Vec<u128> = Vec::new();
|
||||
for i in 0..mxtables_full.len() {
|
||||
let choice = receiver_bits[i] as usize;
|
||||
mxtables.push(mxtables_full[i][choice]);
|
||||
}
|
||||
mxtables
|
||||
}
|
||||
}
|
||||
183
tls-core/src/ghash/slave.rs
Normal file
183
tls-core/src/ghash/slave.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
//! Implements the GHASH Slave. This is the party which holds the X value of
|
||||
//! block multiplication. Slave acts as the sender of the Oblivious
|
||||
//! Transfer and sends masked x_table entries obliviously for each
|
||||
//! bit of Y received from the GHASH Master.
|
||||
|
||||
use super::common::GhashCommon;
|
||||
use super::errors::*;
|
||||
use super::utils::{
|
||||
block_aggregation, block_aggregation_mxtables, block_mult, free_square, masked_xtable,
|
||||
multiply_powers_and_blocks, square_all,
|
||||
};
|
||||
use super::{MXTableFull, SlaveCore};
|
||||
use rand::{CryptoRng, Rng};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
#[derive(PartialEq)]
|
||||
pub enum SlaveState {
|
||||
Initialized,
|
||||
// There may be from 1 to 4 rounds depending on GHASH block count
|
||||
RoundReceived(usize),
|
||||
Complete,
|
||||
}
|
||||
|
||||
pub struct GhashSlave<R> {
|
||||
c: GhashCommon,
|
||||
rng: R,
|
||||
state: SlaveState,
|
||||
}
|
||||
|
||||
impl<R: Rng + CryptoRng> SlaveCore for GhashSlave<R> {
|
||||
fn process_request(&mut self) -> Result<Vec<[u128; 2]>, GhashError> {
|
||||
let retval;
|
||||
let is_complete;
|
||||
match self.state {
|
||||
SlaveState::Initialized => {
|
||||
self.state = SlaveState::RoundReceived(1);
|
||||
retval = self.get_mxtables_for_round1().concat();
|
||||
is_complete = self.is_only_1_round();
|
||||
}
|
||||
SlaveState::RoundReceived(1) => {
|
||||
if self.is_round2_needed() {
|
||||
self.state = SlaveState::RoundReceived(2);
|
||||
retval = self.get_mxtables_for_round(2).concat();
|
||||
is_complete = false;
|
||||
} else {
|
||||
// rounds 2 and 3 will be skipped
|
||||
self.state = SlaveState::RoundReceived(4);
|
||||
retval = self.get_mxtables_for_block_aggr().concat();
|
||||
is_complete = true;
|
||||
}
|
||||
}
|
||||
SlaveState::RoundReceived(2) => {
|
||||
if self.is_round3_needed() {
|
||||
self.state = SlaveState::RoundReceived(3);
|
||||
retval = self.get_mxtables_for_round(3).concat();
|
||||
is_complete = false;
|
||||
} else {
|
||||
// round 3 will be skipped
|
||||
self.state = SlaveState::RoundReceived(4);
|
||||
retval = self.get_mxtables_for_block_aggr().concat();
|
||||
is_complete = true;
|
||||
}
|
||||
}
|
||||
SlaveState::RoundReceived(3) => {
|
||||
self.state = SlaveState::RoundReceived(4);
|
||||
retval = self.get_mxtables_for_block_aggr().concat();
|
||||
is_complete = true;
|
||||
}
|
||||
_ => {
|
||||
return Err(GhashError::OutOfOrder);
|
||||
}
|
||||
}
|
||||
if is_complete {
|
||||
if self.state != SlaveState::RoundReceived(4) {
|
||||
// if the last round was not round 4 (i.e. there was no block
|
||||
// aggregation), then we compute GHASH directly
|
||||
self.c.temp_share =
|
||||
Some(multiply_powers_and_blocks(&self.c.powers, &self.c.blocks));
|
||||
}
|
||||
self.state = SlaveState::Complete;
|
||||
}
|
||||
Ok(retval)
|
||||
}
|
||||
|
||||
fn finalize(&mut self) -> Result<u128, GhashError> {
|
||||
if self.state != SlaveState::Complete {
|
||||
return Err(GhashError::FinalizeCalledTooEarly);
|
||||
}
|
||||
Ok(self.c.temp_share.unwrap())
|
||||
}
|
||||
|
||||
fn is_complete(&mut self) -> bool {
|
||||
self.state == SlaveState::Complete
|
||||
}
|
||||
|
||||
fn export_powers(&mut self) -> BTreeMap<u16, u128> {
|
||||
self.c.export_powers()
|
||||
}
|
||||
|
||||
fn calculate_ot_count(&mut self) -> usize {
|
||||
self.c.calculate_ot_count()
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Rng + CryptoRng> GhashSlave<R> {
|
||||
pub fn new(rng: R, ghash_key_share: u128, blocks: Vec<u128>) -> Result<Self, GhashError> {
|
||||
let c = GhashCommon::new(ghash_key_share, blocks)?;
|
||||
Ok(Self {
|
||||
c,
|
||||
rng,
|
||||
state: SlaveState::Initialized,
|
||||
})
|
||||
}
|
||||
|
||||
fn is_round2_needed(&self) -> bool {
|
||||
self.c.is_round2_needed()
|
||||
}
|
||||
|
||||
fn is_round3_needed(&self) -> bool {
|
||||
self.c.is_round3_needed()
|
||||
}
|
||||
|
||||
fn is_only_1_round(&self) -> bool {
|
||||
self.c.is_only_1_round()
|
||||
}
|
||||
|
||||
/// Returns the masked X table for round 1.
|
||||
/// Since the Master (M) sends bits for his powers in ascending order, we need to
|
||||
/// accomodate that order, i.e. if we need to multiply M's H^1 by
|
||||
/// our H^2 and then multiply M's H^2 by our H^1, then we return [mxtable
|
||||
/// for H^2 + mxtable for H^1].
|
||||
fn get_mxtables_for_round1(&mut self) -> Vec<MXTableFull> {
|
||||
self.c.powers.insert(2, free_square(self.c.powers[&1]));
|
||||
let (masked1, h3_share1) = masked_xtable(&mut self.rng, self.c.powers[&1]);
|
||||
let (masked2, h3_share2) = masked_xtable(&mut self.rng, self.c.powers[&2]);
|
||||
|
||||
self.c.powers.insert(
|
||||
3,
|
||||
block_mult(self.c.powers[&1], self.c.powers[&2]) ^ h3_share1 ^ h3_share2,
|
||||
);
|
||||
// since we just added a new share of powers, we need to square them
|
||||
self.c.powers = square_all(&self.c.powers, self.c.blocks.len() as u16);
|
||||
vec![masked2, masked1]
|
||||
}
|
||||
|
||||
/// Returns masked X tables for either round 2 or round 3.
|
||||
fn get_mxtables_for_round(&mut self, round_no: u8) -> Vec<MXTableFull> {
|
||||
assert!(round_no == 2 || round_no == 3);
|
||||
let mut all_mxtables: Vec<MXTableFull> = Vec::new();
|
||||
for (key, value) in self.c.strategies[(round_no - 2) as usize].clone().iter() {
|
||||
if *key > self.c.max_odd_power {
|
||||
break;
|
||||
}
|
||||
// Since Master sends bits in ascending order: factor1 bits,
|
||||
// factor2 bits, we must return mxtables in descending order:
|
||||
// factor2 mxtable, factor1 mxtable.
|
||||
let factor1 = self.c.powers[&(value[0] as u16)];
|
||||
let factor2 = self.c.powers[&(value[1] as u16)];
|
||||
let (mxtable1, sum1) = masked_xtable(&mut self.rng, factor1);
|
||||
let (mxtable2, sum2) = masked_xtable(&mut self.rng, factor2);
|
||||
all_mxtables.push(mxtable2);
|
||||
all_mxtables.push(mxtable1);
|
||||
|
||||
// our share of power <key> is the locally computed term plus sums
|
||||
// of masks of each cross-term.
|
||||
let local_term = block_mult(factor1, factor2);
|
||||
self.c.powers.insert(*key as u16, local_term ^ sum1 ^ sum2);
|
||||
}
|
||||
// since we just added a few new shares of powers, we need to square them
|
||||
self.c.powers = square_all(&self.c.powers, self.c.blocks.len() as u16);
|
||||
all_mxtables
|
||||
}
|
||||
|
||||
/// Returns masked X tables for the block aggregation method.
|
||||
fn get_mxtables_for_block_aggr(&mut self) -> Vec<MXTableFull> {
|
||||
let share1 = multiply_powers_and_blocks(&self.c.powers, &self.c.blocks);
|
||||
let (aggregated, share2) = block_aggregation(&self.c.powers, &self.c.blocks);
|
||||
let (mxtables, share3) =
|
||||
block_aggregation_mxtables(&mut self.rng, &self.c.powers, &aggregated);
|
||||
self.c.temp_share = Some(share1 ^ share2 ^ share3);
|
||||
mxtables
|
||||
}
|
||||
}
|
||||
587
tls-core/src/ghash/utils.rs
Normal file
587
tls-core/src/ghash/utils.rs
Normal file
@@ -0,0 +1,587 @@
|
||||
use crate::ghash::{MXTableFull, YBits};
|
||||
use mpc_core::utils::u8vec_to_boolvec;
|
||||
use rand::{CryptoRng, Rng};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
/// R is GCM polynomial in little-endian. In hex: "E1000000000000000000000000000000"
|
||||
const R: u128 = 299076299051606071403356588563077529600;
|
||||
|
||||
/// Galois field multiplication of two 128-bit blocks reduced by the GCM polynomial
|
||||
pub fn block_mult(mut x: u128, y: u128) -> u128 {
|
||||
let mut result: u128 = 0;
|
||||
for i in (0..128).rev() {
|
||||
result ^= x * ((y >> i) & 1);
|
||||
x = (x >> 1) ^ ((x & 1) * R);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Returns the squared value. It is called "free" because due to the
|
||||
/// property of Galois multiplication, squaring can be done locally without
|
||||
/// the need for 2PC.
|
||||
pub fn free_square(x: u128) -> u128 {
|
||||
block_mult(x, x)
|
||||
}
|
||||
|
||||
/// Performs squaring of each odd power in "powers" up to and including
|
||||
/// the maximum power "max" and returns an updated map of powers. Squaring
|
||||
/// will be done recursively if needed, e.g if we have power == 1 and "max" is 22,
|
||||
/// then 1 will be squared to get power == 2, then 2 -> 4, 4 -> 8, 8 -> 16.
|
||||
/// Those powers which have already been squared will be skipped.
|
||||
pub fn square_all(powers: &BTreeMap<u16, u128>, max: u16) -> BTreeMap<u16, u128> {
|
||||
let mut new_powers: BTreeMap<u16, u128> = BTreeMap::new();
|
||||
for (power, value) in powers.iter() {
|
||||
// The fact the we had earlier computed more powers that we will ever
|
||||
// need is a sign of a logic error which needs to be investigated.
|
||||
assert!(*power <= max);
|
||||
new_powers.insert(*power, *value);
|
||||
if power % 2 == 0 {
|
||||
continue;
|
||||
}
|
||||
// existing_power is the power for which we have the value.
|
||||
let mut existing_power = *power;
|
||||
while existing_power * 2 <= max {
|
||||
// check if the squaring has already been done, otherwise do it now.
|
||||
let option = powers.get(&(existing_power * 2));
|
||||
let squared_value: u128;
|
||||
if option == None {
|
||||
let value_to_square = new_powers.get(&existing_power).unwrap();
|
||||
squared_value = free_square(*value_to_square);
|
||||
} else {
|
||||
squared_value = *option.unwrap();
|
||||
}
|
||||
new_powers.insert(existing_power * 2, squared_value);
|
||||
existing_power *= 2;
|
||||
}
|
||||
}
|
||||
new_powers
|
||||
}
|
||||
|
||||
/// Finds 2 non-equal summands which add up to the needed sum. The
|
||||
/// first returned summand will be as small as possible.
|
||||
/// E.g if "summands" keys are 1,2,3,5,6 and "sum_needed" is 8, then
|
||||
/// the returned value would be (2,6).
|
||||
pub fn find_sum(summands: &BTreeMap<u16, u128>, sum_needed: u16) -> (u16, u16) {
|
||||
for (i, _) in summands.iter() {
|
||||
for (j, _) in summands.iter() {
|
||||
if *j == *i {
|
||||
continue;
|
||||
}
|
||||
if *i + *j == sum_needed {
|
||||
return (*i, *j);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Should never get here. We only call find_sum when we know in advance
|
||||
// that summands will be found.
|
||||
panic!("summands were not found")
|
||||
}
|
||||
|
||||
/// Returns the maximum odd power that we'll need to compute GHASH in 2PC
|
||||
/// using Block Aggregation, where "max" is maximum power for GHASH
|
||||
/// (i.e it is the amount of GHASH blocks).
|
||||
pub fn find_max_odd_power(max: u16) -> u8 {
|
||||
assert!(max <= 1026);
|
||||
// max_htable's <value> shows how many GHASH blocks can be processed
|
||||
// with Block Aggregation if we have all the sequential shares
|
||||
// starting with 1 up to and including <key>.
|
||||
// e.g. (5, 29) means that if we have shares of H^1,H^2,H^3,H^4,H^5,
|
||||
// then we can process 29 GHASH blocks.
|
||||
// max TLS record size of 16KB requires 1026 GHASH blocks
|
||||
let max_htable: BTreeMap<u8, u16> = BTreeMap::from([
|
||||
(0, 0),
|
||||
(3, 19),
|
||||
(5, 29),
|
||||
(7, 71),
|
||||
(9, 89),
|
||||
(11, 107),
|
||||
(13, 125),
|
||||
(15, 271),
|
||||
(17, 305),
|
||||
(19, 339),
|
||||
(21, 373),
|
||||
(23, 407),
|
||||
(25, 441),
|
||||
(27, 475),
|
||||
(29, 509),
|
||||
(31, 1023),
|
||||
(33, 1025),
|
||||
(35, 1027),
|
||||
]);
|
||||
let mut out = 0u8;
|
||||
for (key, value) in max_htable.iter() {
|
||||
if *value >= max {
|
||||
out = *key;
|
||||
break;
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Multiplies GHASH blocks by the corresponding shares of powers of H and
|
||||
/// returns the sum of all products. If some share is not present, the
|
||||
/// corresponding block is not multiplied at this stage but it will later
|
||||
/// participate in block aggregation.
|
||||
pub fn multiply_powers_and_blocks(powers: &BTreeMap<u16, u128>, blocks: &Vec<u128>) -> u128 {
|
||||
let last_key = *powers.iter().last().unwrap().0;
|
||||
assert!(last_key as usize <= blocks.len());
|
||||
let mut sum = 0u128;
|
||||
for (power, value) in powers.iter() {
|
||||
// in GHASH, H^1 is multiplied with the last block, H^2 with the second to last
|
||||
// block, etc.
|
||||
sum ^= block_mult(*value, blocks[blocks.len() - (*power as usize)]);
|
||||
}
|
||||
sum
|
||||
}
|
||||
|
||||
/// Implements the block aggregation method.
|
||||
pub fn block_aggregation(
|
||||
powers: &BTreeMap<u16, u128>,
|
||||
blocks: &Vec<u128>,
|
||||
) -> (BTreeMap<u16, u128>, u128) {
|
||||
let mut ghash_share = 0u128;
|
||||
let mut aggregated: BTreeMap<u16, u128> = BTreeMap::new();
|
||||
for i in 1..blocks.len() + 1 {
|
||||
if powers.get(&(i as u16)) != None {
|
||||
// we already multiplied the block with this share of power in
|
||||
// multiply_powers_and_blocks()
|
||||
continue;
|
||||
}
|
||||
// else we found a power of H which we don't have.
|
||||
let (small, big) = find_sum(&powers, i as u16);
|
||||
let block = blocks[blocks.len() - i];
|
||||
ghash_share ^= block_mult(
|
||||
block_mult(*powers.get(&small).unwrap(), *powers.get(&big).unwrap()),
|
||||
block,
|
||||
);
|
||||
// initialize the value if it doesn't exist
|
||||
if aggregated.get(&small) == None {
|
||||
aggregated.insert(small, 0u128);
|
||||
}
|
||||
// update value
|
||||
let old_value = *aggregated.get(&small).unwrap();
|
||||
aggregated.insert(
|
||||
small,
|
||||
old_value ^ block_mult(*powers.get(&big).unwrap(), block),
|
||||
);
|
||||
}
|
||||
(aggregated, ghash_share)
|
||||
}
|
||||
|
||||
/// Returns YBits which Master needs to complete Block Aggregation.
|
||||
pub fn block_aggregation_bits(
|
||||
powers: &BTreeMap<u16, u128>,
|
||||
aggregated: &BTreeMap<u16, u128>,
|
||||
) -> Vec<YBits> {
|
||||
let mut all_bits: Vec<YBits> = Vec::new();
|
||||
for (power, value) in aggregated.iter() {
|
||||
// Master sends first bits of power then bits of value. Slave sends
|
||||
// masked x tables in reverse order.
|
||||
all_bits.push(u8vec_to_boolvec(
|
||||
&(*powers.get(power).unwrap()).to_be_bytes(),
|
||||
));
|
||||
all_bits.push(u8vec_to_boolvec(&value.to_be_bytes()));
|
||||
}
|
||||
all_bits
|
||||
}
|
||||
|
||||
/// Returns masked X tables which Slave needs to complete Block Aggregation.
|
||||
pub fn block_aggregation_mxtables<R: Rng + CryptoRng>(
|
||||
rng: &mut R,
|
||||
powers: &BTreeMap<u16, u128>,
|
||||
aggregated: &BTreeMap<u16, u128>,
|
||||
) -> (Vec<MXTableFull>, u128) {
|
||||
let mut all_mxtables: Vec<MXTableFull> = Vec::new();
|
||||
let mut sum = 0u128;
|
||||
for (power, value) in aggregated.iter() {
|
||||
// Slave sends first masked x table of agregated value then masked x
|
||||
// table of power value.
|
||||
let (mxtable1, sum1) = masked_xtable(rng, *value);
|
||||
let (mxtable2, sum2) = masked_xtable(rng, *powers.get(power).unwrap());
|
||||
sum ^= sum1 ^ sum2;
|
||||
all_mxtables.push(mxtable1);
|
||||
all_mxtables.push(mxtable2);
|
||||
}
|
||||
(all_mxtables, sum)
|
||||
}
|
||||
|
||||
/// Returns a table of values of x after each of the 128 rounds of blockMult()
|
||||
fn xtable(mut x: u128) -> Vec<u128> {
|
||||
let mut x_table: Vec<u128> = vec![0u128; 128];
|
||||
for i in 0..128 {
|
||||
x_table[i] = x;
|
||||
x = (x >> 1) ^ ((x & 1) * R);
|
||||
}
|
||||
x_table
|
||||
}
|
||||
|
||||
/// Returns:
|
||||
/// 1) a masked xTable from which OT response will be constructed and
|
||||
/// 2) the XOR-sum of all masks which is our share of the block multiplication product
|
||||
/// For each value of xTable, the masked xTable will contain 2 values:
|
||||
/// 1) a random mask and
|
||||
/// 2) the xTable entry masked with the random mask.
|
||||
pub fn masked_xtable<R: Rng + CryptoRng>(rng: &mut R, x: u128) -> (MXTableFull, u128) {
|
||||
let x_table = xtable(x);
|
||||
// maskSum is the xor sum of all masks
|
||||
let mut mask_sum: u128 = 0;
|
||||
let mut masked_xtable: MXTableFull = vec![[0u128; 2]; 128];
|
||||
for i in 0..128 {
|
||||
let mask: u128 = rng.gen();
|
||||
mask_sum ^= mask;
|
||||
masked_xtable[i][0] = mask;
|
||||
masked_xtable[i][1] = x_table[i] ^ mask;
|
||||
}
|
||||
(masked_xtable, mask_sum)
|
||||
}
|
||||
|
||||
/// Returns the XOR sum of all elements of the vector.
|
||||
pub fn xor_sum(vec: &Vec<u128>) -> u128 {
|
||||
vec.iter().fold(0u128, |acc, x| acc ^ x)
|
||||
}
|
||||
|
||||
/// Converts a flat vector into a vector of chunks of the needed size.
|
||||
pub fn flat_to_chunks<T>(flat: &Vec<T>, chunk_size: usize) -> Vec<Vec<T>>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
let count = flat.len() / chunk_size;
|
||||
let mut vec_chunks: Vec<Vec<T>> = Vec::with_capacity(count);
|
||||
for chunk in flat.chunks(chunk_size) {
|
||||
vec_chunks.push(chunk.to_vec());
|
||||
}
|
||||
vec_chunks
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ghash::{
|
||||
universal_hash::{NewUniversalHash, UniversalHash},
|
||||
GHash,
|
||||
};
|
||||
use rand::SeedableRng;
|
||||
use rand::{thread_rng, Rng};
|
||||
use rand_chacha::ChaCha12Rng;
|
||||
use std::convert::TryInto;
|
||||
|
||||
#[test]
|
||||
fn test_block_mult() {
|
||||
let mut rng = thread_rng();
|
||||
let x: u128 = rng.gen();
|
||||
let y: u128 = rng.gen();
|
||||
assert_eq!(block_mult(x, y), rust_crypto_ghash(x, &vec![y]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_free_square() {
|
||||
let mut rng = thread_rng();
|
||||
let x: u128 = rng.gen();
|
||||
assert_eq!(free_square(x), rust_crypto_ghash(x, &vec![x]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_square_all() {
|
||||
let mut powers_keys: Vec<u16> = vec![1, 3, 5, 6];
|
||||
let mut max_power = 8;
|
||||
let mut powers = setup_square_all(powers_keys);
|
||||
let mut res = square_all(&powers, max_power);
|
||||
let mut res_keys: Vec<u16> = res.keys().cloned().collect();
|
||||
let mut res_values: Vec<u128> = res.values().cloned().collect();
|
||||
assert_eq!(res_keys, vec![1, 2, 3, 4, 5, 6, 8]);
|
||||
assert_eq!(
|
||||
res_values,
|
||||
vec![
|
||||
12346,
|
||||
37384537381450758925419573570612497787,
|
||||
12348,
|
||||
330354857586696702251049094163216396600,
|
||||
12350,
|
||||
12351,
|
||||
226635212255694396134298606764692048245
|
||||
]
|
||||
);
|
||||
|
||||
powers_keys = vec![1, 101];
|
||||
max_power = 255;
|
||||
powers = setup_square_all(powers_keys);
|
||||
res = square_all(&powers, max_power);
|
||||
res_keys = res.keys().cloned().collect();
|
||||
res_values = res.values().cloned().collect();
|
||||
assert_eq!(res_keys, vec![1, 2, 4, 8, 16, 32, 64, 101, 128, 202]);
|
||||
assert_eq!(
|
||||
res_values,
|
||||
vec![
|
||||
12346,
|
||||
37384537381450758925419573570612497787,
|
||||
330354857586696702251049094163216396600,
|
||||
226635212255694396134298606764692048245,
|
||||
57435732663173249101181542372863021389,
|
||||
221430791223604693286277902449016542898,
|
||||
102715545892785352441374614451375486130,
|
||||
12446,
|
||||
308951395680463668816102109032313818778,
|
||||
144387391042136486694176041923180290643
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_square_all_should_panic() {
|
||||
let powers_keys: Vec<u16> = vec![1, 3, 5, 6, 9];
|
||||
let max_power = 8;
|
||||
let powers = setup_square_all(powers_keys);
|
||||
// will panick because we have power 9 when max needed is 8
|
||||
square_all(&powers, max_power);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_sum() {
|
||||
let summands = setup_find_sum();
|
||||
assert_eq!(find_sum(&summands, 8), (3, 5));
|
||||
assert_eq!(find_sum(&summands, 14), (5, 9));
|
||||
assert_eq!(find_sum(&summands, 21), (6, 15));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_find_sum_should_panic() {
|
||||
let summands = setup_find_sum();
|
||||
// will panick because summands must be non-equal, so (15, 15) is not
|
||||
// allowed
|
||||
find_sum(&summands, 30);
|
||||
// will panick because no two summands add up to 25
|
||||
find_sum(&summands, 25);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_max_odd_power() {
|
||||
assert_eq!(find_max_odd_power(1), 3);
|
||||
assert_eq!(find_max_odd_power(20), 5);
|
||||
assert_eq!(find_max_odd_power(100), 11);
|
||||
assert_eq!(find_max_odd_power(1000), 31);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiply_powers_and_blocks() {
|
||||
let mut rng = thread_rng();
|
||||
let block_count = 10;
|
||||
let h: u128 = rng.gen();
|
||||
let powers_of_h = compute_expected_powers(h, 10);
|
||||
let mut powers: BTreeMap<u16, u128> = BTreeMap::new();
|
||||
let mut blocks: Vec<u128> = Vec::new();
|
||||
// init all powers and all blocks
|
||||
for i in 0..block_count {
|
||||
powers.insert(i + 1, powers_of_h[(i + 1) as usize]);
|
||||
blocks.push(rng.gen());
|
||||
}
|
||||
let result = multiply_powers_and_blocks(&powers, &blocks);
|
||||
assert_eq!(result, rust_crypto_ghash(h, &blocks));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_multiply_powers_and_blocks_should_panic() {
|
||||
let mut rng = thread_rng();
|
||||
let block_count = 10;
|
||||
let h: u128 = rng.gen();
|
||||
let powers_of_h = compute_expected_powers(h, block_count);
|
||||
let mut powers: BTreeMap<u16, u128> = BTreeMap::new();
|
||||
let mut blocks: Vec<u128> = Vec::new();
|
||||
// init all powers and all blocks
|
||||
for i in 0..block_count {
|
||||
powers.insert(i + 1, powers_of_h[(i + 1) as usize]);
|
||||
blocks.push(rng.gen());
|
||||
}
|
||||
// insert an extra power to have more powers than blocks which is a
|
||||
// sign of a logic error.
|
||||
powers.insert(
|
||||
block_count + 1,
|
||||
block_mult(powers_of_h[block_count as usize], h),
|
||||
);
|
||||
multiply_powers_and_blocks(&powers, &blocks);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_block_aggregation() {
|
||||
let h: u128 = 123456;
|
||||
let block_count = 10;
|
||||
let powers_of_h = compute_expected_powers(h, block_count);
|
||||
let mut powers_map: BTreeMap<u16, u128> = BTreeMap::new();
|
||||
let mut blocks: Vec<u128> = Vec::with_capacity(block_count as usize);
|
||||
for i in 1..block_count + 1 {
|
||||
powers_map.insert(i, powers_of_h[i as usize]);
|
||||
blocks.push(1234567 + i as u128);
|
||||
}
|
||||
// all powers are in place, so no block aggregation will happen
|
||||
assert_eq!(
|
||||
block_aggregation(&powers_map, &blocks),
|
||||
(BTreeMap::new(), 0)
|
||||
);
|
||||
// remove some powers
|
||||
powers_map.remove(&5);
|
||||
powers_map.remove(&7);
|
||||
let mut expected_map: BTreeMap<u16, u128> = BTreeMap::new();
|
||||
expected_map.insert(1, 6529972824624832318862907648013721286);
|
||||
assert_eq!(
|
||||
block_aggregation(&powers_map, &blocks),
|
||||
(expected_map, 315833047958356732231847338615588728787)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_block_aggregation_bits() {
|
||||
let mut powers_map: BTreeMap<u16, u128> = BTreeMap::new();
|
||||
let mut aggregated_map: BTreeMap<u16, u128> = BTreeMap::new();
|
||||
powers_map.insert(1, 256);
|
||||
aggregated_map.insert(1, 512);
|
||||
let mut expected: [[bool; 128]; 2] = [[false; 128]; 2];
|
||||
expected[0][128 - 9] = true; // set bit for 256
|
||||
expected[1][128 - 10] = true; // set bit for 512
|
||||
assert_eq!(
|
||||
expected.concat(),
|
||||
block_aggregation_bits(&powers_map, &aggregated_map).concat()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_block_aggregation_mxtables() {
|
||||
let mut rng = ChaCha12Rng::seed_from_u64(12345);
|
||||
|
||||
let mut powers_map: BTreeMap<u16, u128> = BTreeMap::new();
|
||||
let mut aggregated_map: BTreeMap<u16, u128> = BTreeMap::new();
|
||||
powers_map.insert(1, 256);
|
||||
aggregated_map.insert(1, 512);
|
||||
let mut expected: [[bool; 128]; 2] = [[false; 128]; 2];
|
||||
expected[0][128 - 9] = true; // set bit for 256
|
||||
expected[1][128 - 10] = true; // set bit for 512
|
||||
let (mxtables, share) = block_aggregation_mxtables(&mut rng, &powers_map, &aggregated_map);
|
||||
assert_eq!(share, 119435139769675579125133100514879089925);
|
||||
// since mxtables output is huge, we check only a few arbitrary elements of it
|
||||
assert_eq!(
|
||||
mxtables.concat()[23],
|
||||
[
|
||||
87470858790173581075227934021272140429,
|
||||
87445059565157736150448673117636328077
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
mxtables.concat()[54],
|
||||
[
|
||||
332311817981764475918232627385210634271,
|
||||
332311817981747626514621748491088166943
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
mxtables.concat()[10],
|
||||
[
|
||||
111226209635646018502889366408372405024,
|
||||
237502869235213026428751037135005139744
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xtable() {
|
||||
let result = xtable(123456u128);
|
||||
// since xtable output is huge, we check only a few arbitrary elements of it
|
||||
assert_eq!(result[0], 123456);
|
||||
assert_eq!(result[39], 57076772936301564299567746252800);
|
||||
assert_eq!(result[111], 12086476800);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_masked_xtable() {
|
||||
let mut rng = thread_rng();
|
||||
let x: u128 = rng.gen();
|
||||
let y: u128 = rng.gen();
|
||||
let expected = block_mult(x, y);
|
||||
assert_eq!(expected, product_from_shares(x, y));
|
||||
|
||||
// corrupt some bytes of y value
|
||||
let mut bad_bytes = y.to_be_bytes();
|
||||
bad_bytes[5] = (bad_bytes[5] + 1) / 255;
|
||||
bad_bytes[10] = (bad_bytes[10] + 1) / 255;
|
||||
bad_bytes[15] = (bad_bytes[15] + 1) / 255;
|
||||
let bad = u128::from_be_bytes(bad_bytes);
|
||||
assert_ne!(expected, product_from_shares(x, bad));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xor_sum() {
|
||||
let mut rng = thread_rng();
|
||||
let mut summands: Vec<u128> = Vec::new();
|
||||
for _i in 0..300 {
|
||||
let rand = rng.gen();
|
||||
summands.push(rand);
|
||||
summands.push(rand);
|
||||
}
|
||||
// xoring the same value twice should result in zero
|
||||
assert_eq!(xor_sum(&summands), 0);
|
||||
summands.push(123456);
|
||||
assert_eq!(xor_sum(&summands), 123456);
|
||||
}
|
||||
|
||||
// compute GHASH using RustCrypto's ghash
|
||||
fn rust_crypto_ghash(h: u128, blocks: &Vec<u128>) -> u128 {
|
||||
let mut ghash = GHash::new(&h.to_be_bytes().into());
|
||||
for block in blocks.iter() {
|
||||
ghash.update(&block.to_be_bytes().into());
|
||||
}
|
||||
let b = ghash.finalize().into_bytes();
|
||||
u128::from_be_bytes(b.as_slice().try_into().unwrap())
|
||||
}
|
||||
|
||||
// prepare the expected powers of h by recursively multiplying h to
|
||||
// itself
|
||||
fn compute_expected_powers(h: u128, max: u16) -> Vec<u128> {
|
||||
// prepare the expected powers of h by recursively multiplying h to
|
||||
// itself
|
||||
let mut powers: Vec<u128> = vec![0u128; (max + 1) as usize];
|
||||
powers[1] = h;
|
||||
let mut prev_power = h;
|
||||
for i in 2..((max + 1) as usize) {
|
||||
powers[i] = block_mult(prev_power, h);
|
||||
prev_power = powers[i];
|
||||
}
|
||||
powers
|
||||
}
|
||||
|
||||
fn setup_find_sum() -> BTreeMap<u16, u128> {
|
||||
let summands_keys: Vec<u16> = vec![1, 3, 5, 6, 8, 9, 12, 15];
|
||||
let mut summands: BTreeMap<u16, u128> = BTreeMap::new();
|
||||
// assign any value > 0 to elements at keys corresponding to
|
||||
// summands_keys
|
||||
for v in summands_keys.iter() {
|
||||
summands.insert(*v, (12345 + *v) as u128);
|
||||
}
|
||||
summands
|
||||
}
|
||||
|
||||
fn setup_square_all(keys: Vec<u16>) -> BTreeMap<u16, u128> {
|
||||
let mut powers: BTreeMap<u16, u128> = BTreeMap::new();
|
||||
// assign any value to elements at keys corresponding to
|
||||
// powers_keys
|
||||
for v in keys.iter() {
|
||||
powers.insert(*v, (12345 + *v) as u128);
|
||||
}
|
||||
powers
|
||||
}
|
||||
|
||||
fn product_from_shares(x: u128, y: u128) -> u128 {
|
||||
// instantiate with empty values, we only need rng for this test
|
||||
let (masked_xtable, my_product_share) = masked_xtable(&mut thread_rng(), x);
|
||||
|
||||
// the other party who has the y value will receive only 1 value (out of 2)
|
||||
// for each entry in maskedXTable via Oblivious Transfer depending on the
|
||||
// bits of y. We simulate that here:
|
||||
let mut his_product_share = 0u128;
|
||||
let bits = u8vec_to_boolvec(&y.to_be_bytes());
|
||||
for i in 0..128 {
|
||||
// the first element in xTable corresponds to the highest bit of y
|
||||
his_product_share ^= masked_xtable[i][bits[i] as usize];
|
||||
}
|
||||
my_product_share ^ his_product_share
|
||||
}
|
||||
}
|
||||
@@ -1,2 +1,4 @@
|
||||
#[cfg(feature = "ghash")]
|
||||
pub mod ghash;
|
||||
#[cfg(feature = "prf")]
|
||||
pub mod prf;
|
||||
|
||||
Reference in New Issue
Block a user