implement GHASH 2PC (#20)

* implement GHASH 2PC

* receiver_tls test

* make GHASH more generic

* wrap temp_share in an Option<>

* fix warnings

* make modules public

* add traits

* add feature

Co-authored-by: themighty1 <you@example.com>
This commit is contained in:
Dan
2022-04-28 19:19:42 +00:00
committed by GitHub
parent e213aa8755
commit 99041da58b
8 changed files with 1481 additions and 1 deletions

View File

@@ -8,14 +8,19 @@ name = "tls_core"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
default = ["prf"]
default = ["prf", "ghash"]
prf = []
ghash = []
[dependencies]
tlsn-mpc-core = { path = "../mpc-core" }
sha2 = { version = "0.10.1", features = ["compress"] }
digest = { version = "0.10.3" }
hmac = { version = "0.12.1" }
rand = "0.8.5"
thiserror = "1.0.30"
[dev-dependencies]
criterion = "0.3.5"
ghash_rc = { package = "ghash", version = "0.4.4" }
rand_chacha = "0.3.1"

View File

@@ -0,0 +1,151 @@
use super::errors::*;
use super::utils::{find_max_odd_power, find_sum, square_all};
use std::collections::BTreeMap;
// GhashCommon is common to both Master and Slave
pub struct GhashCommon {
// blocks are input blocks for GHASH. In TLS the first block is AAD,
// the middle blocks are AES blocks - the ciphertext, the last block
// is len(AAD)+len(ciphertext)
pub blocks: Vec<u128>,
// powers are our XOR shares of the powers of H (H is the GHASH key).
// We need as many powers as there blocks. Value at key==1 corresponds to the share
// of H^1, value at key==2 to the share of H^2 etc.
pub powers: BTreeMap<u16, u128>,
// max_odd_power is the maximum odd power that we'll need to compute
// GHASH in 2PC using Block Aggregation
pub max_odd_power: u8,
// strategies are initialized in ::new(). See comments there.
pub strategies: [BTreeMap<u8, [u8; 2]>; 2],
// temp_share is used to save an intermediate GHASH share
pub temp_share: Option<u128>,
}
impl GhashCommon {
pub fn new(ghash_key_share: u128, blocks: Vec<u128>) -> Result<Self, GhashError> {
if blocks.len() < 3 || blocks.len() > 1026 {
return Err(GhashError::BlockCountWrong);
}
let mut powers = BTreeMap::new();
// GHASH key is our share H^1
powers.insert(1, ghash_key_share);
let max_odd_power = find_max_odd_power(blocks.len() as u16);
powers = square_all(&powers, blocks.len() as u16);
// strategy1 and startegy2 are only relevant for the Block Aggregation method.
// They show what existing shares of the powers of H (H is the GHASH key) we
// will be multiplying (value[0] and value[1]) to obtain other odd shares (<key>).
// Max sequential odd share that we can obtain on first round of
// communication is 19. We already have 1) shares of H^1, H^2, H^3 from
// the Client Finished message and 2) squares of those 3 shares.
// Note that "sequential" is a keyword here. We can't obtain 21 but we
// indeed can obtain 25==24+1, 33==32+1 etc. However with 21 missing,
// even if we have 25,33,etc, there will be a gap and we will not be able
// to obtain all the needed shares by Block Aggregation.
// We request OT for each share in each pair of the strategy, i.e. for
// shares: 4,1,4,3,8,1, etc. Even though it would be possible to introduce
// optimizations in order to avoid requesting OT for the same share more
// than once, that would only save us ~2000 OT instances at the cost of
// complicating the code.
let strategy1: BTreeMap<u8, [u8; 2]> = BTreeMap::from([
(5, [4, 1]),
(7, [4, 3]),
(9, [8, 1]),
(11, [8, 3]),
(13, [12, 1]),
(15, [12, 3]),
(17, [16, 1]),
(19, [16, 3]),
]);
let strategy2: BTreeMap<u8, [u8; 2]> = BTreeMap::from([
(21, [17, 4]),
(23, [17, 6]),
(25, [17, 8]),
(27, [19, 8]),
(29, [17, 12]),
(31, [19, 12]),
(33, [17, 16]),
(35, [19, 16]),
]);
Ok(Self {
max_odd_power,
blocks,
powers,
strategies: [strategy1, strategy2],
temp_share: Some(0u128),
})
}
/// Returns the amount of Oblivious Transfer instances needed to complete
/// the protocol. The purpose is to inform the OT layer.
pub fn calculate_ot_count(&mut self) -> usize {
let mut powers: BTreeMap<u16, u128> = BTreeMap::new();
// only 2 2PC multiplications in round 1
let r1count = 2;
powers.insert(1, 0u128);
powers.insert(2, 0u128);
powers.insert(3, 0u128);
// since we just added a few new shares of powers, we need to square them
self.powers = square_all(&powers, self.blocks.len() as u16);
// merge 2 strategies maps into 1
let strategy: BTreeMap<u8, [u8; 2]> = self.strategies[0]
.clone()
.into_iter()
.chain(self.strategies[1].clone())
.collect();
// number of multiplications in rounds 2 and 3
let mut r2and3count = 0;
for (key, _) in strategy.iter() {
if *key > self.max_odd_power {
break;
};
powers.insert(*key as u16, 0u128);
r2and3count += 2;
}
// since we just added a few new shares of powers, we need to square them
powers = square_all(&powers, self.blocks.len() as u16);
let mut aggregated: BTreeMap<u16, u128> = BTreeMap::new();
for i in 1..self.blocks.len() + 1 {
if powers.get(&(i as u16)) != None {
continue;
}
let (small, _) = find_sum(&powers, i as u16);
// initialize the value if it doesn't exist
if aggregated.get(&small) == None {
aggregated.insert(small, 0u128);
}
}
let r4count = aggregated.len() * 2;
// each multiplication requires 128 instances of Oblivious Transfer
(r1count + r2and3count + r4count) * 128
}
pub fn export_powers(&mut self) -> BTreeMap<u16, u128> {
self.powers.clone()
}
pub fn is_round2_needed(&self) -> bool {
// after round 1 we will have consecutive powers 1,2,3 which is enough
// to compute GHASH for 19 blocks with block aggregation.
self.blocks.len() > 19
}
pub fn is_round3_needed(&self) -> bool {
// after round 2 we will have a max of up to 19 consequitive odd powers
// which allows us to get 339 powers with block aggregation, see max_htable
// in utils::find_max_odd_power()
self.blocks.len() > 339
}
pub fn is_only_1_round(&self) -> bool {
// block agregation is always used except for very small block count
// where powers from round 1 are sufficient to perform direct multiplication
// of blocks by powers
self.blocks.len() <= 4
}
}

View File

@@ -0,0 +1,12 @@
/// Errors that may occur when using ghash module
#[derive(Debug, thiserror::Error)]
pub enum GhashError {
#[error("Message was received out of order")]
OutOfOrder,
#[error("The other party sent data of wrong size")]
DataLengthWrong,
#[error("Tried to pass unsupported block count")]
BlockCountWrong,
#[error("Tried to finalize before the protocol was complete")]
FinalizeCalledTooEarly,
}

View File

@@ -0,0 +1,245 @@
//! Implements the GHASH Master. This is the party which holds the Y value of
//! block multiplication. Master acts as the receiver of the Oblivious
//! Transfer and receives Slaves's masked X table entries obliviously for each
//! bit of Y.
use super::utils::{
block_aggregation, block_aggregation_bits, block_mult, flat_to_chunks,
multiply_powers_and_blocks, square_all, xor_sum,
};
use super::{errors::*, MasterCore};
use crate::ghash::common::GhashCommon;
use crate::ghash::{MXTable, YBits};
use mpc_core::utils::u8vec_to_boolvec;
use std::collections::BTreeMap;
#[derive(PartialEq)]
pub enum MasterState {
Initialized,
// There may be 1, 2, 3 or 4 rounds depending on GHASH block count
RoundSent(usize),
RoundReceived(usize),
Complete,
}
pub struct GhashMaster {
c: GhashCommon,
state: MasterState,
// is_last_round will be set to true by next_request() to indicate that
// after the response is received the state must be set to Complete
is_last_round: bool,
}
impl MasterCore for GhashMaster {
fn next_request(&mut self) -> Result<Vec<bool>, GhashError> {
let retval;
let is_complete;
match self.state {
MasterState::Initialized => {
self.state = MasterState::RoundSent(1);
retval = self.get_ybits_for_round1().concat();
is_complete = self.is_only_1_round();
}
MasterState::RoundReceived(1) => {
if self.is_round2_needed() {
self.state = MasterState::RoundSent(2);
retval = self.get_ybits_for_round(2).concat();
is_complete = false;
} else {
// rounds 2 and 3 will be skipped
self.state = MasterState::RoundSent(4);
retval = self.get_ybits_for_block_aggr().concat();
is_complete = true;
}
}
MasterState::RoundReceived(2) => {
if self.is_round3_needed() {
self.state = MasterState::RoundSent(3);
retval = self.get_ybits_for_round(3).concat();
is_complete = false;
} else {
// round 3 will be skipped
self.state = MasterState::RoundSent(4);
retval = self.get_ybits_for_block_aggr().concat();
is_complete = true;
}
}
MasterState::RoundReceived(3) => {
self.state = MasterState::RoundSent(4);
retval = self.get_ybits_for_block_aggr().concat();
is_complete = true;
}
_ => {
return Err(GhashError::OutOfOrder);
}
}
if is_complete {
self.is_last_round = true;
}
Ok(retval)
}
fn process_response(&mut self, response: &Vec<u128>) -> Result<(), GhashError> {
if response.len() % 128 != 0 {
return Err(GhashError::DataLengthWrong);
}
let mxtables = flat_to_chunks(response, 128);
match self.state {
MasterState::RoundSent(1) => {
self.state = MasterState::RoundReceived(1);
self.process_mxtables_for_round1(&mxtables);
}
MasterState::RoundSent(2) => {
self.state = MasterState::RoundReceived(2);
self.process_mxtables_for_round(&mxtables, 2);
}
MasterState::RoundSent(3) => {
self.state = MasterState::RoundReceived(3);
self.process_mxtables_for_round(&mxtables, 3);
}
MasterState::RoundSent(4) => {
self.state = MasterState::RoundReceived(4);
self.process_mxtables_for_block_aggr(&mxtables);
}
_ => {
return Err(GhashError::OutOfOrder);
}
}
if self.is_last_round {
if self.state != MasterState::RoundReceived(4) {
// if the last round was not round 4 (i.e. there was no block
// aggregation), then we compute GHASH directly
self.c.temp_share =
Some(multiply_powers_and_blocks(&self.c.powers, &self.c.blocks));
}
self.state = MasterState::Complete;
}
Ok(())
}
fn is_complete(&mut self) -> bool {
self.state == MasterState::Complete
}
fn finalize(&mut self) -> Result<u128, GhashError> {
if self.state != MasterState::Complete {
return Err(GhashError::FinalizeCalledTooEarly);
}
Ok(self.c.temp_share.unwrap())
}
/// Returns the amount of Oblivious Transfer instances needed to complete
/// the protocol. The purpose is to inform the OT layer.
fn calculate_ot_count(&mut self) -> usize {
self.c.calculate_ot_count()
}
fn export_powers(&mut self) -> BTreeMap<u16, u128> {
self.c.export_powers()
}
}
impl GhashMaster {
pub fn new(ghash_key_share: u128, blocks: Vec<u128>) -> Result<Self, GhashError> {
let common = GhashCommon::new(ghash_key_share, blocks)?;
Ok(Self {
c: common,
state: MasterState::Initialized,
is_last_round: false,
})
}
fn is_round2_needed(&self) -> bool {
self.c.is_round2_needed()
}
fn is_round3_needed(&self) -> bool {
self.c.is_round3_needed()
}
fn is_only_1_round(&self) -> bool {
self.c.is_only_1_round()
}
/// Returns Y bits to compute H^3.
fn get_ybits_for_round1(&mut self) -> Vec<YBits> {
vec![
u8vec_to_boolvec(&self.c.powers[&1].to_be_bytes()),
u8vec_to_boolvec(&self.c.powers[&2].to_be_bytes()),
]
}
/// Takes masked X tables and computes our share of H^3.
fn process_mxtables_for_round1(&mut self, mxtables: &Vec<MXTable>) {
// the XOR sum of all masked xtables' values plus H^1*H^2 is our share of H^3
self.c.powers.insert(
3,
xor_sum(&mxtables[0])
^ xor_sum(&mxtables[1])
^ block_mult(self.c.powers[&1], self.c.powers[&2]),
);
// since we just added a new share of powers, we need to square them
self.c.powers = square_all(&self.c.powers, self.c.blocks.len() as u16);
}
// Returns Y bits for a given round of communication.
fn get_ybits_for_round(&mut self, round_no: u8) -> Vec<YBits> {
assert!(round_no == 2 || round_no == 3);
let mut bits: Vec<YBits> = Vec::new();
for (key, value) in self.c.strategies[(round_no - 2) as usize].iter() {
if *key > self.c.max_odd_power {
break;
}
bits.push(u8vec_to_boolvec(
&self.c.powers[&(value[0] as u16)].to_be_bytes(),
));
bits.push(u8vec_to_boolvec(
&self.c.powers[&(value[1] as u16)].to_be_bytes(),
));
}
bits
}
/// Processes masked X tables for a given round of communication.
fn process_mxtables_for_round(&mut self, mxtables: &Vec<MXTable>, round_no: u8) {
assert!(round_no == 2 || round_no == 3);
for (count, (power, factors)) in self.c.strategies[(round_no - 2) as usize]
.iter()
.enumerate()
{
if *power > self.c.max_odd_power {
// for every key in the strategy which we processed, there
// must have been 2 masked xtables
assert!(count * 2 == mxtables.len());
break;
}
// the XOR sum of 2 masked xtables' values plus the locally computed
// term is our share of power
let sum = xor_sum(&mxtables[count * 2]) ^ xor_sum(&mxtables[(count * 2) + 1]);
let local_term = block_mult(
self.c.powers[&(factors[0] as u16)],
self.c.powers[&(factors[1] as u16)],
);
self.c.powers.insert(*power as u16, sum ^ local_term);
}
// since we just added a few new shares of powers, we need to square them
self.c.powers = square_all(&self.c.powers, self.c.blocks.len() as u16);
}
/// Returns Y bits for the block aggregation method.
fn get_ybits_for_block_aggr(&mut self) -> Vec<YBits> {
let share1 = multiply_powers_and_blocks(&self.c.powers, &self.c.blocks);
let (aggregated, share2) = block_aggregation(&self.c.powers, &self.c.blocks);
let choice_bits = block_aggregation_bits(&self.c.powers, &aggregated);
self.c.temp_share = Some(share1 ^ share2);
choice_bits
}
/// Processes masked X tables for the block aggregation method.
fn process_mxtables_for_block_aggr(&mut self, mxtables: &Vec<MXTable>) {
let mut share = 0u128;
for table in mxtables.iter() {
share ^= xor_sum(table);
}
self.c.temp_share = Some(self.c.temp_share.unwrap() ^ share);
}
}

295
tls-core/src/ghash/mod.rs Normal file
View File

@@ -0,0 +1,295 @@
//! ghash implements a protocol of computing the AES-GCM's GHASH function in a
//! secure two-party computation (2PC) setting using 1-out-of-2 Oblivious
//! Transfer (OT). The parties start with their secret XOR shares of H (the
//! GHASH key) and at the end each gets their XOR share of the GHASH output.
//! The method is decribed here:
//! (https://tlsnotary.org/how_it_works#section4).
//! As an illustration, let's say that S has his shares H1_s and H2_s and R
//! has her shares H1_r and H2_r. They need to compute shares of H3.
//! H3 = (H1_s + H1_r)*(H2_s + H2_r) = H1_s*H2_s + H1_s*H2_r + H1_r*H2_s +
//! H1_r*H2_r. Term 1 can be computed by S locally and term 4 can be
//! computed by R locally. Only terms 2 and 3 will be computed using
//! GHASH 2PC. R will obliviously request values for bits of H1_r and H2_r.
//! The XOR sum of all values which S will send back plus H1_r*H2_r will
//! become R's share of H3.
//!
//! When performing block multiplication in 2PC, Master holds the Y value and
//! Slave holds the X value. The Slave then computes the X table and masks it.
mod common;
pub mod errors;
pub mod master;
pub mod slave;
mod utils;
use errors::*;
use std::collections::BTreeMap;
/// MXTableFull is masked XTable which Slave has at the beginning of OT.
/// MXTableFull must not be revealed to Master.
type MXTableFull = Vec<[u128; 2]>;
/// MXTable is a masked x table which Master will end up having after OT.
type MXTable = Vec<u128>;
/// YBits are Master's bits of Y in big-endian. Based on these bits
/// Master will send MXTable via OT.
/// The convention for the returned Y bits:
/// A) powers are in an ascending order: first powers[1], then powers[2] etc.
/// B) bits of each power are in big-endian.
type YBits = Vec<bool>;
pub trait MasterCore {
/// Returns choice bits for Oblivious Transfer.
/// While is_complete() returns false, next_request() must be called
/// followed by process_response().
fn next_request(&mut self) -> Result<Vec<bool>, GhashError>;
/// process_response() will be invoked by the Oblivious Transfer impl. It
/// receives masked X tables acc. to our choice bits in next_request().
fn process_response(&mut self, response: &Vec<u128>) -> Result<(), GhashError>;
/// Returns true when the protocol is complete.
fn is_complete(&mut self) -> bool;
/// Returns our GHASH share.
fn finalize(&mut self) -> Result<u128, GhashError>;
/// Returns the amount of Oblivious Transfer instances needed to complete
/// the protocol. The purpose is to inform the OT layer.
fn calculate_ot_count(&mut self) -> usize;
/// Exports powers of the GHASH key obtained at the current stage of the
/// protocol.
fn export_powers(&mut self) -> BTreeMap<u16, u128>;
}
pub trait SlaveCore {
/// Returns the full masked X table which must NOT be passed to Master. It
/// must be consumed by the Oblivious Transfer impl.
fn process_request(&mut self) -> Result<Vec<[u128; 2]>, GhashError>;
/// Returns true when the protocol is complete.
fn is_complete(&mut self) -> bool;
/// Returns our GHASH share.
fn finalize(&mut self) -> Result<u128, GhashError>;
/// Returns the amount of Oblivious Transfer instances needed to complete
/// the protocol. The purpose is to inform the OT layer.
fn calculate_ot_count(&mut self) -> usize;
/// Exports powers of the GHASH key obtained at the current stage of the
/// protocol.
fn export_powers(&mut self) -> BTreeMap<u16, u128>;
}
#[cfg(test)]
mod tests {
use super::errors::GhashError;
use super::utils::block_mult;
use super::{master::GhashMaster, slave::GhashSlave, MasterCore, SlaveCore};
use ghash_rc::{
universal_hash::{NewUniversalHash, UniversalHash},
GHash,
};
use rand::prelude::ThreadRng;
use rand::{thread_rng, Rng};
use std::convert::TryInto;
#[test]
// test only round 1
fn test_round1() {
let block_count = 3;
let (h, mut slave, mut master, blocks) = ghash_setup(block_count);
run_round(&mut slave, &mut master).unwrap();
let ghash = finalize(&mut slave, &mut master);
assert_eq!(ghash, rust_crypto_ghash(h, &blocks));
assert!(master.is_complete());
assert!(slave.is_complete());
}
#[test]
// test state after rounds 1,2 (but before block aggregation)
fn test_round12_before_block_aggregation() {
let block_count = 30;
let (h, mut slave, mut master, _blocks) = ghash_setup(block_count);
run_round(&mut slave, &mut master).unwrap();
run_round(&mut slave, &mut master).unwrap();
let s_powers = slave.export_powers();
let r_powers = master.export_powers();
let all_s_keys: Vec<u16> = s_powers.keys().cloned().collect();
let all_r_keys: Vec<u16> = r_powers.keys().cloned().collect();
let expected_keys = vec![1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28];
let exp_powers = compute_expected_powers(h, block_count as u16);
assert_eq!(all_s_keys, expected_keys);
assert_eq!(all_r_keys, expected_keys);
// compare shares of powers against expected powers
for key in expected_keys.iter() {
assert_eq!(
exp_powers[*key as usize],
*s_powers.get(key).unwrap() ^ *r_powers.get(key).unwrap()
);
}
assert!(!master.is_complete());
assert!(!slave.is_complete());
}
#[test]
// test rounds 1,2,4
fn test_round124() {
let block_count = 30;
let (h, mut slave, mut master, blocks) = ghash_setup(block_count);
run_round(&mut slave, &mut master).unwrap();
run_round(&mut slave, &mut master).unwrap();
run_round(&mut slave, &mut master).unwrap();
let ghash = finalize(&mut slave, &mut master);
assert_eq!(ghash, rust_crypto_ghash(h, &blocks));
assert!(master.is_complete());
assert!(slave.is_complete());
}
#[test]
// test rounds 1,2,3,4
fn test_round1234() {
let block_count = 340;
let (h, mut slave, mut master, blocks) = ghash_setup(block_count);
run_round(&mut slave, &mut master).unwrap();
run_round(&mut slave, &mut master).unwrap();
run_round(&mut slave, &mut master).unwrap();
run_round(&mut slave, &mut master).unwrap();
let ghash = finalize(&mut slave, &mut master);
assert_eq!(ghash, rust_crypto_ghash(h, &blocks));
assert!(master.is_complete());
assert!(slave.is_complete());
}
#[test]
// test export_powers() after round 1
fn test_export_powers() {
let block_count = 340;
let (h, mut slave, mut master, blocks) = ghash_setup(block_count);
run_round(&mut slave, &mut master).unwrap();
let powers_s = slave.export_powers();
let powers_r = master.export_powers();
// we only have 4 consecutive powers 1,2,3,4 after round 1
// we compute ghash for only the first 4 blocks
let mut ghash_share_s = 0u128;
for i in 0..4 {
ghash_share_s ^= block_mult(*powers_s.get(&(4 - i)).unwrap(), blocks[i as usize]);
}
let mut ghash_share_r = 0u128;
for i in 0..4 {
ghash_share_r ^= block_mult(*powers_r.get(&(4 - i)).unwrap(), blocks[i as usize]);
}
let ghash = ghash_share_s ^ ghash_share_r;
assert_eq!(ghash, rust_crypto_ghash(h, &blocks[0..4].to_vec()));
assert!(!master.is_complete());
assert!(!slave.is_complete());
}
#[test]
// test OT count against hard-coded values and also double-check against
// the actual amount of Y bits which the Master requested.
fn test_calculate_ot_count() {
let block_counts = vec![3, 19, 200, 1026];
// expected amount of 2PC block multiplications
let expected = vec![2, 8, 44, 96];
for i in 0..block_counts.len() {
let block_count = block_counts[i];
let (_, mut slave, mut master, _) = ghash_setup(block_count);
let mut ybits_count = 0;
while !master.is_complete() {
let req = master.next_request().unwrap();
ybits_count += req.len();
let resp = slave.process_request().unwrap();
let masked_xtable = simulate_ot(&req, &resp);
master.process_response(&masked_xtable).unwrap();
}
let ot_count = master.calculate_ot_count();
assert!(ot_count == expected[i] * 128);
assert!(ot_count == ybits_count);
}
}
fn finalize(sender: &mut GhashSlave<ThreadRng>, receiver: &mut GhashMaster) -> u128 {
let sender_ghash_share = sender.finalize().unwrap();
let receiver_ghash_share = receiver.finalize().unwrap();
receiver_ghash_share ^ sender_ghash_share
}
fn ghash_setup(block_count: usize) -> (u128, GhashSlave<ThreadRng>, GhashMaster, Vec<u128>) {
let mut rng = thread_rng();
// h is ghash key
let h: u128 = rng.gen();
// h_s is sender's XOR share of h
let h_s: u128 = rng.gen();
// h_r is receiver's XOR share of h
let h_r: u128 = h ^ h_s;
let blocks: Vec<u128> = random_blocks(block_count);
let sender = GhashSlave::new(rng, h_s, blocks.clone()).unwrap();
let receiver = GhashMaster::new(h_r, blocks.clone()).unwrap();
(h, sender, receiver, blocks)
}
fn random_blocks(block_count: usize) -> Vec<u128> {
let mut rng = thread_rng();
let mut blocks: Vec<u128> = Vec::new();
for _i in 0..block_count {
blocks.push(rng.gen());
}
blocks
}
// compute GHASH using RustCrypto's ghash
fn rust_crypto_ghash(h: u128, blocks: &Vec<u128>) -> u128 {
let mut ghash = GHash::new(&h.to_be_bytes().into());
for block in blocks.iter() {
ghash.update(&block.to_be_bytes().into());
}
let b = ghash.finalize().into_bytes();
u128::from_be_bytes(b.as_slice().try_into().unwrap())
}
// prepare the expected powers of h by recursively multiplying h to
// itself
fn compute_expected_powers(h: u128, max: u16) -> Vec<u128> {
// prepare the expected powers of h by recursively multiplying h to
// itself
let mut powers: Vec<u128> = vec![0u128; (max + 1) as usize];
powers[1] = h;
let mut prev_power = h;
for i in 2..((max + 1) as usize) {
powers[i] = block_mult(prev_power, h);
prev_power = powers[i];
}
powers
}
// run_round runs the next round
fn run_round(
sender: &mut GhashSlave<ThreadRng>,
receiver: &mut GhashMaster,
) -> Result<(), GhashError> {
let receiver_bits = receiver.next_request()?;
let masked_xtable_full = sender.process_request()?;
let masked_xtable = simulate_ot(&receiver_bits, &masked_xtable_full);
receiver.process_response(&masked_xtable)?;
Ok(())
}
// normally Master will send his bits via OT to get only 1 out of 2 values
// for each row of masked xtable. Here we simulate this OT behaviour.
fn simulate_ot(receiver_bits: &Vec<bool>, mxtables_full: &Vec<[u128; 2]>) -> Vec<u128> {
assert!(receiver_bits.len() == mxtables_full.len());
let mut mxtables: Vec<u128> = Vec::new();
for i in 0..mxtables_full.len() {
let choice = receiver_bits[i] as usize;
mxtables.push(mxtables_full[i][choice]);
}
mxtables
}
}

183
tls-core/src/ghash/slave.rs Normal file
View File

@@ -0,0 +1,183 @@
//! Implements the GHASH Slave. This is the party which holds the X value of
//! block multiplication. Slave acts as the sender of the Oblivious
//! Transfer and sends masked x_table entries obliviously for each
//! bit of Y received from the GHASH Master.
use super::common::GhashCommon;
use super::errors::*;
use super::utils::{
block_aggregation, block_aggregation_mxtables, block_mult, free_square, masked_xtable,
multiply_powers_and_blocks, square_all,
};
use super::{MXTableFull, SlaveCore};
use rand::{CryptoRng, Rng};
use std::collections::BTreeMap;
#[derive(PartialEq)]
pub enum SlaveState {
Initialized,
// There may be from 1 to 4 rounds depending on GHASH block count
RoundReceived(usize),
Complete,
}
pub struct GhashSlave<R> {
c: GhashCommon,
rng: R,
state: SlaveState,
}
impl<R: Rng + CryptoRng> SlaveCore for GhashSlave<R> {
fn process_request(&mut self) -> Result<Vec<[u128; 2]>, GhashError> {
let retval;
let is_complete;
match self.state {
SlaveState::Initialized => {
self.state = SlaveState::RoundReceived(1);
retval = self.get_mxtables_for_round1().concat();
is_complete = self.is_only_1_round();
}
SlaveState::RoundReceived(1) => {
if self.is_round2_needed() {
self.state = SlaveState::RoundReceived(2);
retval = self.get_mxtables_for_round(2).concat();
is_complete = false;
} else {
// rounds 2 and 3 will be skipped
self.state = SlaveState::RoundReceived(4);
retval = self.get_mxtables_for_block_aggr().concat();
is_complete = true;
}
}
SlaveState::RoundReceived(2) => {
if self.is_round3_needed() {
self.state = SlaveState::RoundReceived(3);
retval = self.get_mxtables_for_round(3).concat();
is_complete = false;
} else {
// round 3 will be skipped
self.state = SlaveState::RoundReceived(4);
retval = self.get_mxtables_for_block_aggr().concat();
is_complete = true;
}
}
SlaveState::RoundReceived(3) => {
self.state = SlaveState::RoundReceived(4);
retval = self.get_mxtables_for_block_aggr().concat();
is_complete = true;
}
_ => {
return Err(GhashError::OutOfOrder);
}
}
if is_complete {
if self.state != SlaveState::RoundReceived(4) {
// if the last round was not round 4 (i.e. there was no block
// aggregation), then we compute GHASH directly
self.c.temp_share =
Some(multiply_powers_and_blocks(&self.c.powers, &self.c.blocks));
}
self.state = SlaveState::Complete;
}
Ok(retval)
}
fn finalize(&mut self) -> Result<u128, GhashError> {
if self.state != SlaveState::Complete {
return Err(GhashError::FinalizeCalledTooEarly);
}
Ok(self.c.temp_share.unwrap())
}
fn is_complete(&mut self) -> bool {
self.state == SlaveState::Complete
}
fn export_powers(&mut self) -> BTreeMap<u16, u128> {
self.c.export_powers()
}
fn calculate_ot_count(&mut self) -> usize {
self.c.calculate_ot_count()
}
}
impl<R: Rng + CryptoRng> GhashSlave<R> {
pub fn new(rng: R, ghash_key_share: u128, blocks: Vec<u128>) -> Result<Self, GhashError> {
let c = GhashCommon::new(ghash_key_share, blocks)?;
Ok(Self {
c,
rng,
state: SlaveState::Initialized,
})
}
fn is_round2_needed(&self) -> bool {
self.c.is_round2_needed()
}
fn is_round3_needed(&self) -> bool {
self.c.is_round3_needed()
}
fn is_only_1_round(&self) -> bool {
self.c.is_only_1_round()
}
/// Returns the masked X table for round 1.
/// Since the Master (M) sends bits for his powers in ascending order, we need to
/// accomodate that order, i.e. if we need to multiply M's H^1 by
/// our H^2 and then multiply M's H^2 by our H^1, then we return [mxtable
/// for H^2 + mxtable for H^1].
fn get_mxtables_for_round1(&mut self) -> Vec<MXTableFull> {
self.c.powers.insert(2, free_square(self.c.powers[&1]));
let (masked1, h3_share1) = masked_xtable(&mut self.rng, self.c.powers[&1]);
let (masked2, h3_share2) = masked_xtable(&mut self.rng, self.c.powers[&2]);
self.c.powers.insert(
3,
block_mult(self.c.powers[&1], self.c.powers[&2]) ^ h3_share1 ^ h3_share2,
);
// since we just added a new share of powers, we need to square them
self.c.powers = square_all(&self.c.powers, self.c.blocks.len() as u16);
vec![masked2, masked1]
}
/// Returns masked X tables for either round 2 or round 3.
fn get_mxtables_for_round(&mut self, round_no: u8) -> Vec<MXTableFull> {
assert!(round_no == 2 || round_no == 3);
let mut all_mxtables: Vec<MXTableFull> = Vec::new();
for (key, value) in self.c.strategies[(round_no - 2) as usize].clone().iter() {
if *key > self.c.max_odd_power {
break;
}
// Since Master sends bits in ascending order: factor1 bits,
// factor2 bits, we must return mxtables in descending order:
// factor2 mxtable, factor1 mxtable.
let factor1 = self.c.powers[&(value[0] as u16)];
let factor2 = self.c.powers[&(value[1] as u16)];
let (mxtable1, sum1) = masked_xtable(&mut self.rng, factor1);
let (mxtable2, sum2) = masked_xtable(&mut self.rng, factor2);
all_mxtables.push(mxtable2);
all_mxtables.push(mxtable1);
// our share of power <key> is the locally computed term plus sums
// of masks of each cross-term.
let local_term = block_mult(factor1, factor2);
self.c.powers.insert(*key as u16, local_term ^ sum1 ^ sum2);
}
// since we just added a few new shares of powers, we need to square them
self.c.powers = square_all(&self.c.powers, self.c.blocks.len() as u16);
all_mxtables
}
/// Returns masked X tables for the block aggregation method.
fn get_mxtables_for_block_aggr(&mut self) -> Vec<MXTableFull> {
let share1 = multiply_powers_and_blocks(&self.c.powers, &self.c.blocks);
let (aggregated, share2) = block_aggregation(&self.c.powers, &self.c.blocks);
let (mxtables, share3) =
block_aggregation_mxtables(&mut self.rng, &self.c.powers, &aggregated);
self.c.temp_share = Some(share1 ^ share2 ^ share3);
mxtables
}
}

587
tls-core/src/ghash/utils.rs Normal file
View File

@@ -0,0 +1,587 @@
use crate::ghash::{MXTableFull, YBits};
use mpc_core::utils::u8vec_to_boolvec;
use rand::{CryptoRng, Rng};
use std::collections::BTreeMap;
/// R is GCM polynomial in little-endian. In hex: "E1000000000000000000000000000000"
const R: u128 = 299076299051606071403356588563077529600;
/// Galois field multiplication of two 128-bit blocks reduced by the GCM polynomial
pub fn block_mult(mut x: u128, y: u128) -> u128 {
let mut result: u128 = 0;
for i in (0..128).rev() {
result ^= x * ((y >> i) & 1);
x = (x >> 1) ^ ((x & 1) * R);
}
result
}
/// Returns the squared value. It is called "free" because due to the
/// property of Galois multiplication, squaring can be done locally without
/// the need for 2PC.
pub fn free_square(x: u128) -> u128 {
block_mult(x, x)
}
/// Performs squaring of each odd power in "powers" up to and including
/// the maximum power "max" and returns an updated map of powers. Squaring
/// will be done recursively if needed, e.g if we have power == 1 and "max" is 22,
/// then 1 will be squared to get power == 2, then 2 -> 4, 4 -> 8, 8 -> 16.
/// Those powers which have already been squared will be skipped.
pub fn square_all(powers: &BTreeMap<u16, u128>, max: u16) -> BTreeMap<u16, u128> {
let mut new_powers: BTreeMap<u16, u128> = BTreeMap::new();
for (power, value) in powers.iter() {
// The fact the we had earlier computed more powers that we will ever
// need is a sign of a logic error which needs to be investigated.
assert!(*power <= max);
new_powers.insert(*power, *value);
if power % 2 == 0 {
continue;
}
// existing_power is the power for which we have the value.
let mut existing_power = *power;
while existing_power * 2 <= max {
// check if the squaring has already been done, otherwise do it now.
let option = powers.get(&(existing_power * 2));
let squared_value: u128;
if option == None {
let value_to_square = new_powers.get(&existing_power).unwrap();
squared_value = free_square(*value_to_square);
} else {
squared_value = *option.unwrap();
}
new_powers.insert(existing_power * 2, squared_value);
existing_power *= 2;
}
}
new_powers
}
/// Finds 2 non-equal summands which add up to the needed sum. The
/// first returned summand will be as small as possible.
/// E.g if "summands" keys are 1,2,3,5,6 and "sum_needed" is 8, then
/// the returned value would be (2,6).
pub fn find_sum(summands: &BTreeMap<u16, u128>, sum_needed: u16) -> (u16, u16) {
for (i, _) in summands.iter() {
for (j, _) in summands.iter() {
if *j == *i {
continue;
}
if *i + *j == sum_needed {
return (*i, *j);
}
}
}
// Should never get here. We only call find_sum when we know in advance
// that summands will be found.
panic!("summands were not found")
}
/// Returns the maximum odd power that we'll need to compute GHASH in 2PC
/// using Block Aggregation, where "max" is maximum power for GHASH
/// (i.e it is the amount of GHASH blocks).
pub fn find_max_odd_power(max: u16) -> u8 {
assert!(max <= 1026);
// max_htable's <value> shows how many GHASH blocks can be processed
// with Block Aggregation if we have all the sequential shares
// starting with 1 up to and including <key>.
// e.g. (5, 29) means that if we have shares of H^1,H^2,H^3,H^4,H^5,
// then we can process 29 GHASH blocks.
// max TLS record size of 16KB requires 1026 GHASH blocks
let max_htable: BTreeMap<u8, u16> = BTreeMap::from([
(0, 0),
(3, 19),
(5, 29),
(7, 71),
(9, 89),
(11, 107),
(13, 125),
(15, 271),
(17, 305),
(19, 339),
(21, 373),
(23, 407),
(25, 441),
(27, 475),
(29, 509),
(31, 1023),
(33, 1025),
(35, 1027),
]);
let mut out = 0u8;
for (key, value) in max_htable.iter() {
if *value >= max {
out = *key;
break;
}
}
out
}
/// Multiplies GHASH blocks by the corresponding shares of powers of H and
/// returns the sum of all products. If some share is not present, the
/// corresponding block is not multiplied at this stage but it will later
/// participate in block aggregation.
pub fn multiply_powers_and_blocks(powers: &BTreeMap<u16, u128>, blocks: &Vec<u128>) -> u128 {
let last_key = *powers.iter().last().unwrap().0;
assert!(last_key as usize <= blocks.len());
let mut sum = 0u128;
for (power, value) in powers.iter() {
// in GHASH, H^1 is multiplied with the last block, H^2 with the second to last
// block, etc.
sum ^= block_mult(*value, blocks[blocks.len() - (*power as usize)]);
}
sum
}
/// Implements the block aggregation method.
pub fn block_aggregation(
powers: &BTreeMap<u16, u128>,
blocks: &Vec<u128>,
) -> (BTreeMap<u16, u128>, u128) {
let mut ghash_share = 0u128;
let mut aggregated: BTreeMap<u16, u128> = BTreeMap::new();
for i in 1..blocks.len() + 1 {
if powers.get(&(i as u16)) != None {
// we already multiplied the block with this share of power in
// multiply_powers_and_blocks()
continue;
}
// else we found a power of H which we don't have.
let (small, big) = find_sum(&powers, i as u16);
let block = blocks[blocks.len() - i];
ghash_share ^= block_mult(
block_mult(*powers.get(&small).unwrap(), *powers.get(&big).unwrap()),
block,
);
// initialize the value if it doesn't exist
if aggregated.get(&small) == None {
aggregated.insert(small, 0u128);
}
// update value
let old_value = *aggregated.get(&small).unwrap();
aggregated.insert(
small,
old_value ^ block_mult(*powers.get(&big).unwrap(), block),
);
}
(aggregated, ghash_share)
}
/// Returns YBits which Master needs to complete Block Aggregation.
pub fn block_aggregation_bits(
powers: &BTreeMap<u16, u128>,
aggregated: &BTreeMap<u16, u128>,
) -> Vec<YBits> {
let mut all_bits: Vec<YBits> = Vec::new();
for (power, value) in aggregated.iter() {
// Master sends first bits of power then bits of value. Slave sends
// masked x tables in reverse order.
all_bits.push(u8vec_to_boolvec(
&(*powers.get(power).unwrap()).to_be_bytes(),
));
all_bits.push(u8vec_to_boolvec(&value.to_be_bytes()));
}
all_bits
}
/// Returns masked X tables which Slave needs to complete Block Aggregation.
pub fn block_aggregation_mxtables<R: Rng + CryptoRng>(
rng: &mut R,
powers: &BTreeMap<u16, u128>,
aggregated: &BTreeMap<u16, u128>,
) -> (Vec<MXTableFull>, u128) {
let mut all_mxtables: Vec<MXTableFull> = Vec::new();
let mut sum = 0u128;
for (power, value) in aggregated.iter() {
// Slave sends first masked x table of agregated value then masked x
// table of power value.
let (mxtable1, sum1) = masked_xtable(rng, *value);
let (mxtable2, sum2) = masked_xtable(rng, *powers.get(power).unwrap());
sum ^= sum1 ^ sum2;
all_mxtables.push(mxtable1);
all_mxtables.push(mxtable2);
}
(all_mxtables, sum)
}
/// Returns a table of values of x after each of the 128 rounds of blockMult()
fn xtable(mut x: u128) -> Vec<u128> {
let mut x_table: Vec<u128> = vec![0u128; 128];
for i in 0..128 {
x_table[i] = x;
x = (x >> 1) ^ ((x & 1) * R);
}
x_table
}
/// Returns:
/// 1) a masked xTable from which OT response will be constructed and
/// 2) the XOR-sum of all masks which is our share of the block multiplication product
/// For each value of xTable, the masked xTable will contain 2 values:
/// 1) a random mask and
/// 2) the xTable entry masked with the random mask.
pub fn masked_xtable<R: Rng + CryptoRng>(rng: &mut R, x: u128) -> (MXTableFull, u128) {
let x_table = xtable(x);
// maskSum is the xor sum of all masks
let mut mask_sum: u128 = 0;
let mut masked_xtable: MXTableFull = vec![[0u128; 2]; 128];
for i in 0..128 {
let mask: u128 = rng.gen();
mask_sum ^= mask;
masked_xtable[i][0] = mask;
masked_xtable[i][1] = x_table[i] ^ mask;
}
(masked_xtable, mask_sum)
}
/// Returns the XOR sum of all elements of the vector.
pub fn xor_sum(vec: &Vec<u128>) -> u128 {
vec.iter().fold(0u128, |acc, x| acc ^ x)
}
/// Converts a flat vector into a vector of chunks of the needed size.
pub fn flat_to_chunks<T>(flat: &Vec<T>, chunk_size: usize) -> Vec<Vec<T>>
where
T: Clone,
{
let count = flat.len() / chunk_size;
let mut vec_chunks: Vec<Vec<T>> = Vec::with_capacity(count);
for chunk in flat.chunks(chunk_size) {
vec_chunks.push(chunk.to_vec());
}
vec_chunks
}
#[cfg(test)]
mod tests {
use super::*;
use ghash::{
universal_hash::{NewUniversalHash, UniversalHash},
GHash,
};
use rand::SeedableRng;
use rand::{thread_rng, Rng};
use rand_chacha::ChaCha12Rng;
use std::convert::TryInto;
#[test]
fn test_block_mult() {
let mut rng = thread_rng();
let x: u128 = rng.gen();
let y: u128 = rng.gen();
assert_eq!(block_mult(x, y), rust_crypto_ghash(x, &vec![y]));
}
#[test]
fn test_free_square() {
let mut rng = thread_rng();
let x: u128 = rng.gen();
assert_eq!(free_square(x), rust_crypto_ghash(x, &vec![x]));
}
#[test]
fn test_square_all() {
let mut powers_keys: Vec<u16> = vec![1, 3, 5, 6];
let mut max_power = 8;
let mut powers = setup_square_all(powers_keys);
let mut res = square_all(&powers, max_power);
let mut res_keys: Vec<u16> = res.keys().cloned().collect();
let mut res_values: Vec<u128> = res.values().cloned().collect();
assert_eq!(res_keys, vec![1, 2, 3, 4, 5, 6, 8]);
assert_eq!(
res_values,
vec![
12346,
37384537381450758925419573570612497787,
12348,
330354857586696702251049094163216396600,
12350,
12351,
226635212255694396134298606764692048245
]
);
powers_keys = vec![1, 101];
max_power = 255;
powers = setup_square_all(powers_keys);
res = square_all(&powers, max_power);
res_keys = res.keys().cloned().collect();
res_values = res.values().cloned().collect();
assert_eq!(res_keys, vec![1, 2, 4, 8, 16, 32, 64, 101, 128, 202]);
assert_eq!(
res_values,
vec![
12346,
37384537381450758925419573570612497787,
330354857586696702251049094163216396600,
226635212255694396134298606764692048245,
57435732663173249101181542372863021389,
221430791223604693286277902449016542898,
102715545892785352441374614451375486130,
12446,
308951395680463668816102109032313818778,
144387391042136486694176041923180290643
]
);
}
#[test]
#[should_panic]
fn test_square_all_should_panic() {
let powers_keys: Vec<u16> = vec![1, 3, 5, 6, 9];
let max_power = 8;
let powers = setup_square_all(powers_keys);
// will panick because we have power 9 when max needed is 8
square_all(&powers, max_power);
}
#[test]
fn test_find_sum() {
let summands = setup_find_sum();
assert_eq!(find_sum(&summands, 8), (3, 5));
assert_eq!(find_sum(&summands, 14), (5, 9));
assert_eq!(find_sum(&summands, 21), (6, 15));
}
#[test]
#[should_panic]
fn test_find_sum_should_panic() {
let summands = setup_find_sum();
// will panick because summands must be non-equal, so (15, 15) is not
// allowed
find_sum(&summands, 30);
// will panick because no two summands add up to 25
find_sum(&summands, 25);
}
#[test]
fn test_find_max_odd_power() {
assert_eq!(find_max_odd_power(1), 3);
assert_eq!(find_max_odd_power(20), 5);
assert_eq!(find_max_odd_power(100), 11);
assert_eq!(find_max_odd_power(1000), 31);
}
#[test]
fn test_multiply_powers_and_blocks() {
let mut rng = thread_rng();
let block_count = 10;
let h: u128 = rng.gen();
let powers_of_h = compute_expected_powers(h, 10);
let mut powers: BTreeMap<u16, u128> = BTreeMap::new();
let mut blocks: Vec<u128> = Vec::new();
// init all powers and all blocks
for i in 0..block_count {
powers.insert(i + 1, powers_of_h[(i + 1) as usize]);
blocks.push(rng.gen());
}
let result = multiply_powers_and_blocks(&powers, &blocks);
assert_eq!(result, rust_crypto_ghash(h, &blocks));
}
#[test]
#[should_panic]
fn test_multiply_powers_and_blocks_should_panic() {
let mut rng = thread_rng();
let block_count = 10;
let h: u128 = rng.gen();
let powers_of_h = compute_expected_powers(h, block_count);
let mut powers: BTreeMap<u16, u128> = BTreeMap::new();
let mut blocks: Vec<u128> = Vec::new();
// init all powers and all blocks
for i in 0..block_count {
powers.insert(i + 1, powers_of_h[(i + 1) as usize]);
blocks.push(rng.gen());
}
// insert an extra power to have more powers than blocks which is a
// sign of a logic error.
powers.insert(
block_count + 1,
block_mult(powers_of_h[block_count as usize], h),
);
multiply_powers_and_blocks(&powers, &blocks);
}
#[test]
fn test_block_aggregation() {
let h: u128 = 123456;
let block_count = 10;
let powers_of_h = compute_expected_powers(h, block_count);
let mut powers_map: BTreeMap<u16, u128> = BTreeMap::new();
let mut blocks: Vec<u128> = Vec::with_capacity(block_count as usize);
for i in 1..block_count + 1 {
powers_map.insert(i, powers_of_h[i as usize]);
blocks.push(1234567 + i as u128);
}
// all powers are in place, so no block aggregation will happen
assert_eq!(
block_aggregation(&powers_map, &blocks),
(BTreeMap::new(), 0)
);
// remove some powers
powers_map.remove(&5);
powers_map.remove(&7);
let mut expected_map: BTreeMap<u16, u128> = BTreeMap::new();
expected_map.insert(1, 6529972824624832318862907648013721286);
assert_eq!(
block_aggregation(&powers_map, &blocks),
(expected_map, 315833047958356732231847338615588728787)
);
}
#[test]
fn test_block_aggregation_bits() {
let mut powers_map: BTreeMap<u16, u128> = BTreeMap::new();
let mut aggregated_map: BTreeMap<u16, u128> = BTreeMap::new();
powers_map.insert(1, 256);
aggregated_map.insert(1, 512);
let mut expected: [[bool; 128]; 2] = [[false; 128]; 2];
expected[0][128 - 9] = true; // set bit for 256
expected[1][128 - 10] = true; // set bit for 512
assert_eq!(
expected.concat(),
block_aggregation_bits(&powers_map, &aggregated_map).concat()
);
}
#[test]
fn test_block_aggregation_mxtables() {
let mut rng = ChaCha12Rng::seed_from_u64(12345);
let mut powers_map: BTreeMap<u16, u128> = BTreeMap::new();
let mut aggregated_map: BTreeMap<u16, u128> = BTreeMap::new();
powers_map.insert(1, 256);
aggregated_map.insert(1, 512);
let mut expected: [[bool; 128]; 2] = [[false; 128]; 2];
expected[0][128 - 9] = true; // set bit for 256
expected[1][128 - 10] = true; // set bit for 512
let (mxtables, share) = block_aggregation_mxtables(&mut rng, &powers_map, &aggregated_map);
assert_eq!(share, 119435139769675579125133100514879089925);
// since mxtables output is huge, we check only a few arbitrary elements of it
assert_eq!(
mxtables.concat()[23],
[
87470858790173581075227934021272140429,
87445059565157736150448673117636328077
]
);
assert_eq!(
mxtables.concat()[54],
[
332311817981764475918232627385210634271,
332311817981747626514621748491088166943
]
);
assert_eq!(
mxtables.concat()[10],
[
111226209635646018502889366408372405024,
237502869235213026428751037135005139744
]
);
}
#[test]
fn test_xtable() {
let result = xtable(123456u128);
// since xtable output is huge, we check only a few arbitrary elements of it
assert_eq!(result[0], 123456);
assert_eq!(result[39], 57076772936301564299567746252800);
assert_eq!(result[111], 12086476800);
}
#[test]
fn test_masked_xtable() {
let mut rng = thread_rng();
let x: u128 = rng.gen();
let y: u128 = rng.gen();
let expected = block_mult(x, y);
assert_eq!(expected, product_from_shares(x, y));
// corrupt some bytes of y value
let mut bad_bytes = y.to_be_bytes();
bad_bytes[5] = (bad_bytes[5] + 1) / 255;
bad_bytes[10] = (bad_bytes[10] + 1) / 255;
bad_bytes[15] = (bad_bytes[15] + 1) / 255;
let bad = u128::from_be_bytes(bad_bytes);
assert_ne!(expected, product_from_shares(x, bad));
}
#[test]
fn test_xor_sum() {
let mut rng = thread_rng();
let mut summands: Vec<u128> = Vec::new();
for _i in 0..300 {
let rand = rng.gen();
summands.push(rand);
summands.push(rand);
}
// xoring the same value twice should result in zero
assert_eq!(xor_sum(&summands), 0);
summands.push(123456);
assert_eq!(xor_sum(&summands), 123456);
}
// compute GHASH using RustCrypto's ghash
fn rust_crypto_ghash(h: u128, blocks: &Vec<u128>) -> u128 {
let mut ghash = GHash::new(&h.to_be_bytes().into());
for block in blocks.iter() {
ghash.update(&block.to_be_bytes().into());
}
let b = ghash.finalize().into_bytes();
u128::from_be_bytes(b.as_slice().try_into().unwrap())
}
// prepare the expected powers of h by recursively multiplying h to
// itself
fn compute_expected_powers(h: u128, max: u16) -> Vec<u128> {
// prepare the expected powers of h by recursively multiplying h to
// itself
let mut powers: Vec<u128> = vec![0u128; (max + 1) as usize];
powers[1] = h;
let mut prev_power = h;
for i in 2..((max + 1) as usize) {
powers[i] = block_mult(prev_power, h);
prev_power = powers[i];
}
powers
}
fn setup_find_sum() -> BTreeMap<u16, u128> {
let summands_keys: Vec<u16> = vec![1, 3, 5, 6, 8, 9, 12, 15];
let mut summands: BTreeMap<u16, u128> = BTreeMap::new();
// assign any value > 0 to elements at keys corresponding to
// summands_keys
for v in summands_keys.iter() {
summands.insert(*v, (12345 + *v) as u128);
}
summands
}
fn setup_square_all(keys: Vec<u16>) -> BTreeMap<u16, u128> {
let mut powers: BTreeMap<u16, u128> = BTreeMap::new();
// assign any value to elements at keys corresponding to
// powers_keys
for v in keys.iter() {
powers.insert(*v, (12345 + *v) as u128);
}
powers
}
fn product_from_shares(x: u128, y: u128) -> u128 {
// instantiate with empty values, we only need rng for this test
let (masked_xtable, my_product_share) = masked_xtable(&mut thread_rng(), x);
// the other party who has the y value will receive only 1 value (out of 2)
// for each entry in maskedXTable via Oblivious Transfer depending on the
// bits of y. We simulate that here:
let mut his_product_share = 0u128;
let bits = u8vec_to_boolvec(&y.to_be_bytes());
for i in 0..128 {
// the first element in xTable corresponds to the highest bit of y
his_product_share ^= masked_xtable[i][bits[i] as usize];
}
my_product_share ^ his_product_share
}
}

View File

@@ -1,2 +1,4 @@
#[cfg(feature = "ghash")]
pub mod ghash;
#[cfg(feature = "prf")]
pub mod prf;