diff --git a/.gitignore b/.gitignore index 6452652ce..412bb21c6 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,6 @@ Cargo.lock .vscode/ tlsn-mpc-circuits/compiled/* + +# ignore compiled circuits +**/bin/*.bin \ No newline at end of file diff --git a/mpc/mpc-core/src/garble/label/mod.rs b/mpc/mpc-core/src/garble/label/mod.rs index 19e014e13..37de5b7b7 100644 --- a/mpc/mpc-core/src/garble/label/mod.rs +++ b/mpc/mpc-core/src/garble/label/mod.rs @@ -63,7 +63,7 @@ pub struct Delta(Block); impl Delta { /// Creates new random Delta - pub(crate) fn random(rng: &mut R) -> Self { + pub fn random(rng: &mut R) -> Self { let mut block = Block::random(rng); block.set_lsb(); Self(block) diff --git a/prf/Cargo.toml b/prf/Cargo.toml new file mode 100644 index 000000000..d890693d6 --- /dev/null +++ b/prf/Cargo.toml @@ -0,0 +1,48 @@ +[workspace] +members = [ + "hmac-sha256-utils", + "hmac-sha256-circuits", + "hmac-sha256-core", + "hmac-sha256", +] + +[workspace.dependencies] +# tlsn +tlsn-mpc-circuits = { path = "../mpc/mpc-circuits" } +tlsn-mpc-core = { path = "../mpc/mpc-core" } +tlsn-mpc-aio = { path = "../mpc/mpc-aio" } +tlsn-utils = { path = "../utils/utils" } +tlsn-utils-aio = { path = "../utils/utils-aio" } + +# rand +rand = "0.8" +rand_chacha = "0.3" + +# crypto +sha2 = "0.10" +hmac = "0.12" +digest = "0.10" + +# async +async-trait = "0.1" +futures = "0.3" +futures-util = "0.3" +tokio = "1.23" +tokio-util = "0.7" + +# serialization +serde = "1.0" +prost = "0.9" +prost-build = "0.9" + +# error/log +thiserror = "1" + +# testing +criterion = "0.3" + +# misc +derive_builder = "0.11" +once_cell = "1" +generic-array = "0.14" +rayon = "1" diff --git a/prf/hmac-sha256-circuits/Cargo.toml b/prf/hmac-sha256-circuits/Cargo.toml new file mode 100644 index 000000000..d860540b7 --- /dev/null +++ b/prf/hmac-sha256-circuits/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "tlsn-hmac-sha256-circuits" +version = "0.1.0" +edition = "2021" + +[lib] +name = "hmac_sha256_circuits" + +[features] +default = [] + +[dependencies] +tlsn-hmac-sha256-utils = { path = "../hmac-sha256-utils" } +tlsn-mpc-circuits.workspace = true +tlsn-utils.workspace = true + +[dev-dependencies] +generic-array.workspace = true +rand.workspace = true +rand_chacha.workspace = true +sha2 = { workspace = true, features = ["compress"] } diff --git a/prf/hmac-sha256-circuits/src/hmac_sha256.rs b/prf/hmac-sha256-circuits/src/hmac_sha256.rs new file mode 100644 index 000000000..f0ab6fd08 --- /dev/null +++ b/prf/hmac-sha256-circuits/src/hmac_sha256.rs @@ -0,0 +1,295 @@ +use std::sync::Arc; + +use mpc_circuits::{ + builder::{map_bytes, CircuitBuilder, Feed, Gates, WireHandle}, + circuits::nbit_xor, + BitOrder, Circuit, ValueType, +}; +use utils::bits::IterToBits; + +use crate::{add_sha256_compress, add_sha256_finalize}; +use hmac_sha256_utils::SHA256_INITIAL_STATE; + +/// Computes the outer and inner states of HMAC-SHA256. +/// +/// Outer state is H(key ⊕ opad) +/// +/// Inner state is H(key ⊕ ipad) +/// +/// # Arguments +/// +/// * `builder` - Mutable reference to the circuit builder +/// * `key` - N-byte key (must be <= 64 bytes) +/// * `const_zero` - 1-bit constant zero +/// * `const_one` - 1-bit constant one +/// +/// # Returns +/// +/// * `outer_state` - 256-bit outer state +/// * `inner_state` - 256-bit inner state +pub fn add_hmac_sha256_partial( + builder: &mut CircuitBuilder, + key: &[WireHandle], + const_zero: &WireHandle, + const_one: &WireHandle, +) -> (Vec>, Vec>) { + let xor_circ = nbit_xor(512); + + let xor_opad = builder.add_circ(&xor_circ); + let xor_ipad = builder.add_circ(&xor_circ); + + let key_opad = { + let a = xor_opad.input(0).expect("xor should have input 0"); + let b = xor_opad.input(1).expect("xor should have input 1"); + + // Connect key wires + builder.connect(key, &a[..key.len()]); + // Connect zero pads + builder.connect_fan_out(*const_zero, &a[key.len()..]); + + // Connect opad wires + map_bytes( + builder, + BitOrder::Msb0, + *const_zero, + *const_one, + &b[..], + &[0x5cu8; 64], + ); + + xor_opad.output(0).expect("xor should have output 0") + }; + + let key_ipad = { + let a = xor_ipad.input(0).expect("xor should have input 0"); + let b = xor_ipad.input(1).expect("xor should have input 1"); + + // Connect key wires + builder.connect(key, &a[..key.len()]); + // Connect zero pads + builder.connect_fan_out(*const_zero, &a[key.len()..]); + + // Connect ipad wires + map_bytes( + builder, + BitOrder::Msb0, + *const_zero, + *const_one, + &b[..], + &[0x36; 64], + ); + + xor_ipad.output(0).expect("xor should have output 0") + }; + + let sha256_initial_state = SHA256_INITIAL_STATE + .into_msb0_iter() + .map(|bit| if bit { *const_one } else { *const_zero }) + .collect::>(); + + let outer_state = add_sha256_compress(builder, &key_opad[..], &sha256_initial_state); + let inner_state = add_sha256_compress(builder, &key_ipad[..], &sha256_initial_state); + + (outer_state, inner_state) +} + +/// Computes HMAC(k, m) using existing key hash states. +/// +/// # Inputs +/// +/// * `builder` - Mutable reference to the circuit builder +/// * `outer_state` - 256-bit outer hash state +/// * `inner_state` - 256-bit inner hash state +/// * `msg` - Arbitrary length message +/// * `const_zero` - 1-bit constant zero +/// * `const_one` - 1-bit constant one +/// +/// # Returns +/// +/// * `hash` - 256-bit HMAC-SHA256 hash +pub fn add_hmac_sha256_finalize( + builder: &mut CircuitBuilder, + outer_state: &[WireHandle], + inner_state: &[WireHandle], + msg: &[WireHandle], + const_zero: &WireHandle, + const_one: &WireHandle, +) -> Vec> { + let inner_hash = add_sha256_finalize(builder, msg, inner_state, const_zero, const_one, 64); + let outer_hash = + add_sha256_finalize(builder, &inner_hash, outer_state, const_zero, const_one, 64); + + outer_hash +} + +/// Computes HMAC(k, m) using existing key hash states. +/// +/// # Inputs +/// +/// 0. OUTER_STATE: 32-byte outer hash state +/// 1. INNER_STATE: 32-byte inner hash state +/// 2. MSG: N-byte message +/// +/// # Outputs +/// +/// 0. HASH: 32-byte hash +pub fn hmac_sha256_finalize(len: usize) -> Arc { + let mut builder = + CircuitBuilder::new(&format!("{len}byte_sha256"), "", "0.1.0", BitOrder::Msb0); + + let outer_state = builder.add_input( + "OUTER_STATE", + "32-byte outer hash state", + ValueType::Bytes, + 256, + ); + let inner_state = builder.add_input( + "INNER_STATE", + "32-byte inner hash state", + ValueType::Bytes, + 256, + ); + let msg = builder.add_input( + "MSG", + &format!("{len}-byte message"), + ValueType::Bytes, + len * 8, + ); + let const_zero = builder.add_input( + "const_zero", + "input that is always 0", + ValueType::ConstZero, + 1, + ); + let const_one = builder.add_input( + "const_one", + "input that is always 1", + ValueType::ConstOne, + 1, + ); + + let mut builder = builder.build_inputs(); + + let hash = add_hmac_sha256_finalize( + &mut builder, + &outer_state[..], + &inner_state[..], + &msg[..], + &const_zero[0], + &const_one[0], + ); + + let mut builder = builder.build_gates(); + + let hash_output = builder.add_output("HASH", "32-byte hash", ValueType::Bytes, 256); + + builder.connect(&hash[..], &hash_output[..]); + + builder + .build_circuit() + .expect("failed to build hmac_sha256") +} + +#[cfg(test)] +mod tests { + use super::*; + + use hmac_sha256_utils::{hmac, partial_sha256_digest}; + use mpc_circuits::{circuits::test_circ, Value}; + + #[test] + #[ignore = "expensive"] + fn test_hmac_sha256_finalize() { + let key = [69u8; 32]; + let msg = [42u8; 47]; + + let circ = hmac_sha256_finalize(msg.len()); + + let key_opad = key + .iter() + .chain(&[0u8; 32]) + .map(|k| k ^ 0x5cu8) + .collect::>(); + + let key_ipad = key + .iter() + .chain(&[0u8; 32]) + .map(|k| k ^ 0x36u8) + .collect::>(); + + let outer_hash_state = partial_sha256_digest(&key_opad); + let inner_hash_state = partial_sha256_digest(&key_ipad); + + let expected = hmac(&key, &msg); + + test_circ( + &circ, + &[ + Value::Bytes( + outer_hash_state + .into_iter() + .map(|chunk| chunk.to_be_bytes()) + .flatten() + .collect(), + ), + Value::Bytes( + inner_hash_state + .into_iter() + .map(|chunk| chunk.to_be_bytes()) + .flatten() + .collect(), + ), + Value::Bytes(msg.to_vec()), + ], + &[Value::Bytes(expected)], + ); + } + + #[test] + #[ignore = "expensive"] + fn test_hmac_sha256_finalize_multi_block() { + let key = [69u8; 32]; + let msg = [42u8; 79]; + + let circ = hmac_sha256_finalize(msg.len()); + + let key_opad = key + .iter() + .chain(&[0u8; 32]) + .map(|k| k ^ 0x5cu8) + .collect::>(); + + let key_ipad = key + .iter() + .chain(&[0u8; 32]) + .map(|k| k ^ 0x36u8) + .collect::>(); + + let outer_hash_state = partial_sha256_digest(&key_opad); + let inner_hash_state = partial_sha256_digest(&key_ipad); + + let expected = hmac(&key, &msg); + + test_circ( + &circ, + &[ + Value::Bytes( + outer_hash_state + .into_iter() + .map(|chunk| chunk.to_be_bytes()) + .flatten() + .collect(), + ), + Value::Bytes( + inner_hash_state + .into_iter() + .map(|chunk| chunk.to_be_bytes()) + .flatten() + .collect(), + ), + Value::Bytes(msg.to_vec()), + ], + &[Value::Bytes(expected)], + ); + } +} diff --git a/prf/hmac-sha256-circuits/src/lib.rs b/prf/hmac-sha256-circuits/src/lib.rs new file mode 100644 index 000000000..c6028b62f --- /dev/null +++ b/prf/hmac-sha256-circuits/src/lib.rs @@ -0,0 +1,13 @@ +mod hmac_sha256; +mod master_secret; +mod prf; +mod session_keys; +mod sha256; +mod verify_data; + +pub use hmac_sha256::{add_hmac_sha256_finalize, add_hmac_sha256_partial, hmac_sha256_finalize}; +pub use master_secret::master_secret; +pub use prf::{add_prf, prf}; +pub use session_keys::session_keys; +pub use sha256::{add_sha256_compress, add_sha256_finalize, sha256}; +pub use verify_data::verify_data; diff --git a/prf/hmac-sha256-circuits/src/master_secret.rs b/prf/hmac-sha256-circuits/src/master_secret.rs new file mode 100644 index 000000000..81fcac353 --- /dev/null +++ b/prf/hmac-sha256-circuits/src/master_secret.rs @@ -0,0 +1,162 @@ +use std::sync::Arc; + +use mpc_circuits::{builder::CircuitBuilder, BitOrder, Circuit, ValueType}; +use utils::bits::IterToBits; + +use crate::{add_hmac_sha256_partial, add_prf}; + +/// Master secret +/// +/// Computes the master secret (MS), returning the outer and inner HMAC states. +/// +/// Outer state is H(master_secret ⊕ opad) +/// +/// Inner state is H(master_secret ⊕ ipad) +/// +/// Inputs: +/// +/// 0. PMS: 32-byte pre-master secret +/// 1. CLIENT_RAND: 32-byte client random +/// 2. SERVER_RAND: 32-byte server random +/// +/// Outputs: +/// +/// 0. OUTER_STATE: 32-byte HMAC outer hash state +/// 1. INNER_STATE: 32-byte HMAC inner hash state +pub fn master_secret() -> Arc { + let mut builder = CircuitBuilder::new("master_secret", "", "0.1.0", BitOrder::Msb0); + + let pms = builder.add_input("PMS", "32-byte PMS, big endian", ValueType::Bytes, 256); + let client_random = builder.add_input( + "CLIENT_RAND", + "32-byte client random", + ValueType::Bytes, + 256, + ); + let server_random = builder.add_input( + "SERVER_RAND", + "32-byte server random", + ValueType::Bytes, + 256, + ); + + let const_zero = builder.add_input( + "const_zero", + "input that is always 0", + ValueType::ConstZero, + 1, + ); + let const_one = builder.add_input( + "const_one", + "input that is always 1", + ValueType::ConstOne, + 1, + ); + + let mut builder = builder.build_inputs(); + + let (pms_outer_state, pms_inner_state) = + add_hmac_sha256_partial(&mut builder, &pms[..], &const_zero[0], &const_one[0]); + + let label = b"master secret" + .into_msb0_iter() + .map(|bit| if bit { const_one[0] } else { const_zero[0] }) + .collect::>(); + let seed = client_random[..] + .iter() + .chain(&server_random[..]) + .copied() + .collect::>(); + + let ms = add_prf( + &mut builder, + &pms_outer_state, + &pms_inner_state, + &const_zero[0], + &const_one[0], + &label, + &seed, + 48, + ); + + let (ms_outer_state, ms_inner_state) = + add_hmac_sha256_partial(&mut builder, &ms, &const_zero[0], &const_one[0]); + + let mut builder = builder.build_gates(); + + let outer_state = builder.add_output( + "OUTER_STATE", + "32-byte HMAC outer hash state", + ValueType::Bytes, + 256, + ); + + builder.connect(&ms_outer_state[..], &outer_state[..]); + + let inner_state = builder.add_output( + "INNER_STATE", + "32-byte HMAC inner hash state", + ValueType::Bytes, + 256, + ); + + builder.connect(&ms_inner_state[..], &inner_state[..]); + + builder + .build_circuit() + .expect("failed to build master_secret") +} + +#[cfg(test)] +mod tests { + use super::*; + + use hmac_sha256_utils::{partial_hmac, prf}; + use mpc_circuits::{circuits::test_circ, Value}; + + #[test] + #[ignore = "expensive"] + fn test_master_secret() { + let circ = master_secret(); + + println!("MS Circuit size: {}", circ.and_count()); + + let pms = [69u8; 32]; + let client_random = [1u8; 32]; + let server_random = [2u8; 32]; + + let seed = client_random + .iter() + .chain(&server_random) + .copied() + .collect::>(); + + let ms = prf(&pms, b"master secret", &seed, 48); + + let (expected_outer_state, expected_inner_state) = partial_hmac(&ms); + + let expected_outer_state = expected_outer_state + .into_iter() + .map(|v| v.to_be_bytes()) + .flatten() + .collect::>(); + let expected_inner_state = expected_inner_state + .into_iter() + .map(|v| v.to_be_bytes()) + .flatten() + .collect::>(); + + test_circ( + &circ, + &[ + Value::Bytes(pms.to_vec()), + Value::Bytes(client_random.to_vec()), + Value::Bytes(server_random.to_vec()), + ], + &[ + Value::Bytes(expected_outer_state), + Value::Bytes(expected_inner_state), + ], + ); + } +} diff --git a/prf/hmac-sha256-circuits/src/prf.rs b/prf/hmac-sha256-circuits/src/prf.rs new file mode 100644 index 000000000..fbdcff14c --- /dev/null +++ b/prf/hmac-sha256-circuits/src/prf.rs @@ -0,0 +1,234 @@ +//! This module provides an implementation of the HMAC-SHA256 PRF defined in [RFC 5246](https://www.rfc-editor.org/rfc/rfc5246#section-5). + +use std::sync::Arc; + +use mpc_circuits::{ + builder::{CircuitBuilder, Feed, Gates, WireHandle}, + BitOrder, Circuit, ValueType, +}; + +use utils::bits::IterToBits; + +use crate::{add_hmac_sha256_finalize, add_hmac_sha256_partial}; + +// P_hash(secret, seed) = +// HMAC_hash(secret, A(1) + seed) + +// HMAC_hash(secret, A(2) + seed) + +// HMAC_hash(secret, A(3) + seed) + ... +fn add_p_hash( + builder: &mut CircuitBuilder, + outer_state: &[WireHandle], + inner_state: &[WireHandle], + const_zero: &WireHandle, + const_one: &WireHandle, + seed: &[WireHandle], + iterations: usize, +) -> Vec> { + // A() is defined as: + // + // A(0) = seed + // A(i) = HMAC_hash(secret, A(i-1)) + let mut a_cache: Vec<_> = Vec::with_capacity(iterations + 1); + a_cache.push(seed.to_vec()); + + for i in 0..iterations { + let a_i = add_hmac_sha256_finalize( + builder, + outer_state, + inner_state, + &a_cache[i], + const_zero, + const_one, + ); + a_cache.push(a_i); + } + + // HMAC_hash(secret, A(i) + seed) + let mut output: Vec> = Vec::with_capacity(iterations * 32 * 8); + for i in 0..iterations { + let mut a_i_seed = a_cache[i + 1].clone(); + a_i_seed.extend_from_slice(seed); + + let hash = add_hmac_sha256_finalize( + builder, + outer_state, + inner_state, + &a_i_seed, + const_zero, + const_one, + ); + output.extend_from_slice(&hash); + } + + output +} + +/// Computes PRF(secret, label, seed) +/// +/// # Arguments +/// +/// * `builder` - Mutable reference to the circuit builder +/// * `outer_state` - The outer state of HMAC-SHA256 +/// * `inner_state` - The inner state of HMAC-SHA256 +/// * `const_zero` - A constant wire that is always 0 +/// * `const_one` - A constant wire that is always 1 +/// * `label` - The label to use +/// * `seed` - The seed to use +/// * `bytes` - The number of bytes to output +/// +/// # Returns +/// +/// * `prf_bytes` - `bytes` bytes of output +pub fn add_prf( + builder: &mut CircuitBuilder, + outer_state: &[WireHandle], + inner_state: &[WireHandle], + const_zero: &WireHandle, + const_one: &WireHandle, + label: &[WireHandle], + seed: &[WireHandle], + bytes: usize, +) -> Vec> { + let iterations = bytes / 32 + (bytes % 32 != 0) as usize; + + let mut label_seed = label.to_vec(); + label_seed.extend_from_slice(seed); + + let p_hash = add_p_hash( + builder, + outer_state, + inner_state, + const_zero, + const_one, + &label_seed, + iterations, + ); + + // Truncate to the desired number of bytes + let prf_bytes = p_hash[..bytes * 8].to_vec(); + + prf_bytes +} + +/// Computes PRF(key, seed) +/// +/// Inputs: +/// +/// 0. KEY: 32-byte key +/// 1. SEED: N-byte seed +/// +/// Outputs: +/// +/// 0. BYTES: N-byte output +/// +/// # Arguments +/// * `name` - The name of the circuit +/// * `description` - The description of the circuit +/// * `label` - The label to use +/// * `key_len` - The length of the key in bytes +/// * `seed_len` - The length of the seed in bytes +/// * `output_len` - The number of bytes to generate +pub fn prf( + name: &str, + description: &str, + label: &[u8], + key_len: usize, + seed_len: usize, + output_len: usize, +) -> Arc { + let mut builder = CircuitBuilder::new(name, description, "0.1.0", BitOrder::Msb0); + + let key = builder.add_input( + "KEY", + &format!("{key_len}-byte key"), + ValueType::Bytes, + key_len * 8, + ); + let seed = builder.add_input( + "SEED", + &format!("{seed_len}-byte seed"), + ValueType::Bytes, + seed_len * 8, + ); + let const_zero = builder.add_input( + "const_zero", + "input that is always 0", + ValueType::ConstZero, + 1, + ); + let const_one = builder.add_input( + "const_one", + "input that is always 1", + ValueType::ConstOne, + 1, + ); + + let mut builder = builder.build_inputs(); + + let label = label + .into_iter() + .copied() + .into_msb0_iter() + .map(|bit| if bit { const_one[0] } else { const_zero[0] }) + .collect::>(); + + let (outer_state, inner_state) = + add_hmac_sha256_partial(&mut builder, &key[..], &const_zero[0], &const_one[0]); + + let prf_bytes = add_prf( + &mut builder, + &outer_state, + &inner_state, + &const_zero[0], + &const_one[0], + &label, + &seed[..], + output_len, + ); + + let mut builder = builder.build_gates(); + + let bytes_out = builder.add_output( + "BYTES", + &format!("{output_len}-byte output"), + ValueType::Bytes, + output_len * 8, + ); + + builder.connect(&prf_bytes, &bytes_out[..]); + + builder.build_circuit().expect("failed to build prf") +} + +#[cfg(test)] +mod tests { + use super::*; + + use mpc_circuits::{circuits::test_circ, Value}; + + #[test] + #[ignore = "expensive"] + fn test_prf() { + let pms = [69u8; 32]; + let label = b"master secret"; + let client_random = [42u8; 32]; + let server_random = [69u8; 32]; + + let seed = { + let mut seed = Vec::new(); + seed.extend_from_slice(&client_random); + seed.extend_from_slice(&server_random); + seed + }; + + let circ = prf("ms", "", label, 32, 64, 48); + + let expected = hmac_sha256_utils::prf(&pms, label, &seed, 48); + + test_circ( + &circ, + &[Value::Bytes(pms.to_vec()), Value::Bytes(seed)], + &[Value::Bytes(expected)], + ); + } +} diff --git a/prf/hmac-sha256-circuits/src/session_keys.rs b/prf/hmac-sha256-circuits/src/session_keys.rs new file mode 100644 index 000000000..c275c48ca --- /dev/null +++ b/prf/hmac-sha256-circuits/src/session_keys.rs @@ -0,0 +1,172 @@ +use std::sync::Arc; + +use mpc_circuits::{builder::CircuitBuilder, BitOrder, Circuit, ValueType}; +use utils::bits::IterToBits; + +use crate::add_prf; + +/// Session Keys +/// +/// Compute expanded p1 which consists of client_write_key + server_write_key +/// Compute expanded p2 which consists of client_IV + server_IV +/// +/// Inputs: +/// +/// 0. OUTER_HASH_STATE: 32-byte MS outer-hash state +/// 1. INNER_HASH_STATE: 32-byte MS inner-hash state +/// 2. CLIENT_RAND: 32-byte client random +/// 3. SERVER_RAND: 32-byte server random +/// +/// Outputs: +/// +/// 0. CWK: 16-byte client write-key +/// 1. SWK: 16-byte server write-key +/// 2. CIV: 4-byte client IV +/// 3. SIV: 4-byte server IV +pub fn session_keys() -> Arc { + let mut builder = CircuitBuilder::new("session_keys", "", "0.1.0", BitOrder::Msb0); + + let outer_state = builder.add_input( + "OUTER_HASH_STATE", + "32-byte MS outer-hash state", + ValueType::Bytes, + 256, + ); + let inner_state = builder.add_input( + "INNER_HASH_STATE", + "32-byte MS inner-hash state", + ValueType::Bytes, + 256, + ); + let client_random = builder.add_input( + "CLIENT_RAND", + "32-byte client random", + ValueType::Bytes, + 256, + ); + let server_random = builder.add_input( + "SERVER_RAND", + "32-byte server random", + ValueType::Bytes, + 256, + ); + let const_zero = builder.add_input( + "const_zero", + "input that is always 0", + ValueType::ConstZero, + 1, + ); + let const_one = builder.add_input( + "const_one", + "input that is always 1", + ValueType::ConstOne, + 1, + ); + + let mut builder = builder.build_inputs(); + + let label = b"key expansion" + .into_msb0_iter() + .map(|bit| if bit { const_one[0] } else { const_zero[0] }) + .collect::>(); + let seed = server_random[..] + .iter() + .chain(&client_random[..]) + .copied() + .collect::>(); + + let key_material = add_prf( + &mut builder, + &outer_state[..], + &inner_state[..], + &const_zero[0], + &const_one[0], + &label, + &seed, + 40, + ); + + let mut builder = builder.build_gates(); + + let cwk = builder.add_output("CWK", "16-byte client write-key", ValueType::Bytes, 128); + let swk = builder.add_output("SWK", "16-byte server write-key", ValueType::Bytes, 128); + let civ = builder.add_output("CIV", "4-byte client IV", ValueType::Bytes, 32); + let siv = builder.add_output("SIV", "4-byte server IV", ValueType::Bytes, 32); + + builder.connect(&key_material[..128], &cwk[..]); + + builder.connect(&key_material[128..256], &swk[..]); + + builder.connect(&key_material[256..288], &civ[..]); + + builder.connect(&key_material[288..], &siv[..]); + + builder + .build_circuit() + .expect("failed to build session_keys") +} + +#[cfg(test)] +mod tests { + use super::*; + use mpc_circuits::{circuits::test_circ, Value}; + + #[test] + #[ignore = "expensive"] + fn test_session_keys() { + let circ = session_keys(); + + println!("KE Circuit size: {}", circ.and_count()); + + let ms = [42u8; 48]; + let client_random = [1u8; 32]; + let server_random = [2u8; 32]; + let seed = server_random + .iter() + .chain(&client_random) + .copied() + .collect::>(); + + let (outer_hash_state, inner_hash_state) = hmac_sha256_utils::partial_hmac(&ms); + + let key_material = hmac_sha256_utils::prf(&ms, b"key expansion", &seed, 40); + + // split into client/server_write_key and client/server_write_iv + let mut cwk = [0u8; 16]; + cwk.copy_from_slice(&key_material[0..16]); + let mut swk = [0u8; 16]; + swk.copy_from_slice(&key_material[16..32]); + let mut civ = [0u8; 4]; + civ.copy_from_slice(&key_material[32..36]); + let mut siv = [0u8; 4]; + siv.copy_from_slice(&key_material[36..40]); + + test_circ( + &circ, + &[ + Value::Bytes( + outer_hash_state + .into_iter() + .map(|chunk| chunk.to_be_bytes()) + .flatten() + .collect(), + ), + Value::Bytes( + inner_hash_state + .into_iter() + .map(|chunk| chunk.to_be_bytes()) + .flatten() + .collect(), + ), + Value::Bytes(client_random.to_vec()), + Value::Bytes(server_random.to_vec()), + ], + &[ + Value::Bytes(cwk.to_vec()), + Value::Bytes(swk.to_vec()), + Value::Bytes(civ.to_vec()), + Value::Bytes(siv.to_vec()), + ], + ); + } +} diff --git a/prf/hmac-sha256-circuits/src/sha256.rs b/prf/hmac-sha256-circuits/src/sha256.rs new file mode 100644 index 000000000..ee46a58f2 --- /dev/null +++ b/prf/hmac-sha256-circuits/src/sha256.rs @@ -0,0 +1,187 @@ +use std::sync::Arc; + +use mpc_circuits::{ + builder::{CircuitBuilder, Feed, Gates, WireHandle}, + BitOrder, Circuit, ValueType, SHA_256, +}; +use utils::bits::{IterToBits, ToBits}; + +use hmac_sha256_utils::SHA256_INITIAL_STATE; + +/// Computes SHA-256 compression function +/// +/// # Arguments +/// +/// * `builder` - Mutable reference to the circuit builder +/// * `msg` - 512-bit message +/// * `initial_state` - 256-bit initial SHA256 state +/// +/// # Returns +/// +/// * `output_state` - 256-bit output state +pub fn add_sha256_compress( + builder: &mut CircuitBuilder, + msg: &[WireHandle], + initial_state: &[WireHandle], +) -> Vec> { + let sha256 = builder.add_circ(&SHA_256); + + let msg_input = sha256.input(0).expect("sha256 missing input 0"); + let state_input = sha256.input(1).expect("sha256 missing input 1"); + + builder.connect(msg, &msg_input[..]); + builder.connect(initial_state, &state_input[..]); + + let output_state = sha256.output(0).expect("sha256 missing output 0")[..].to_vec(); + + output_state +} + +/// Computes SHA-256 hash of an arbitrary length message +/// +/// # Arguments +/// +/// * `builder` - Mutable reference to the circuit builder +/// * `msg` - Arbitrary length message +/// * `initial_state` - 256-bit initial SHA256 state +/// * `const_zero` - 1-bit constant zero +/// * `const_one` - 1-bit constant one +/// * `start_pos` - The number of bytes already processed in the initial state +/// +/// # Returns +/// +/// * `hash` - 256-bit SHA256 hash +pub fn add_sha256_finalize( + builder: &mut CircuitBuilder, + msg: &[WireHandle], + initial_state: &[WireHandle], + const_zero: &WireHandle, + const_one: &WireHandle, + start_pos: usize, +) -> Vec> { + // begin with the original message of length L bits + // append a single '1' bit + // append K '0' bits, where K is the minimum number >= 0 such that (L + 1 + K + 64) is a multiple of 512 + // append L as a 64-bit big-endian integer, making the total post-processed length a multiple of 512 bits + // such that the bits in the message are: 1 , (the number of bits will be a multiple of 512) + + let bit_len = msg.len(); + let processed_bit_len = (bit_len + (start_pos * 8)) as u64; + + // K length + let zero_pad_len = 512 - ((bit_len + 65) % 512); + + let mut padded_msg: Vec> = Vec::with_capacity(bit_len + 65 + zero_pad_len); + + padded_msg.extend(&msg[..]); + // append a single '1' bit + padded_msg.push(*const_one); + // append K '0' bits, where K is the minimum number >= 0 such that (L + 1 + K + 64) is a multiple of 512 + padded_msg.extend(vec![*const_zero; zero_pad_len]); + // append L as a 64-bit big-endian integer, making the total post-processed length a multiple of 512 bits + padded_msg.extend(processed_bit_len.into_msb0_iter().map(|bit| { + if bit { + *const_one + } else { + *const_zero + } + })); + + debug_assert!(padded_msg.len() % 512 == 0); + + let hash = padded_msg + .chunks(512) + .fold(initial_state.to_vec(), |state, msg| { + add_sha256_compress(builder, &msg[..], &state[..]) + }); + + hash +} + +/// Computes a SHA256 hash of an arbitrary length message. +/// +/// Inputs: +/// +/// 0. MSG: N-byte message +/// +/// Outputs: +/// +/// 0. HASH: 32-byte SHA2 hash +/// +/// Arguments: +/// +/// * `len`: The number of bytes to hash +pub fn sha256(len: usize) -> Arc { + let mut builder = + CircuitBuilder::new(&format!("{len}byte_sha256"), "", "0.1.0", BitOrder::Msb0); + + let msg = builder.add_input( + "MSG", + &format!("{len}-byte message"), + ValueType::Bytes, + len * 8, + ); + let const_zero = builder.add_input( + "const_zero", + "input that is always 0", + ValueType::ConstZero, + 1, + ); + let const_one = builder.add_input( + "const_one", + "input that is always 1", + ValueType::ConstOne, + 1, + ); + + let mut builder = builder.build_inputs(); + + let initial_state = SHA256_INITIAL_STATE + .into_msb0_iter() + .map(|bit| if bit { const_one[0] } else { const_zero[0] }) + .collect::>(); + + let hash = add_sha256_finalize( + &mut builder, + &msg[..], + &initial_state, + &const_zero[0], + &const_one[0], + 0, + ); + + let mut builder = builder.build_gates(); + + let hash_output = builder.add_output("HASH", "32-byte SHA2 hash", ValueType::Bytes, 256); + + builder.connect(&hash, &hash_output[..]); + + builder + .build_circuit() + .expect("failed to build sha256_finalize") +} + +#[cfg(test)] +mod tests { + use super::*; + use mpc_circuits::{circuits::test_circ, Value}; + use sha2::{Digest, Sha256}; + + #[test] + #[ignore = "expensive"] + fn test_sha256() { + let msg = [69u8; 100]; + + let circ = sha256(msg.len()); + + let mut hasher = Sha256::new(); + hasher.update(msg); + let expected = hasher.finalize().to_vec(); + + test_circ( + &circ, + &[Value::Bytes(msg.to_vec())], + &[Value::Bytes(expected)], + ); + } +} diff --git a/prf/hmac-sha256-circuits/src/verify_data.rs b/prf/hmac-sha256-circuits/src/verify_data.rs new file mode 100644 index 000000000..7241ffc65 --- /dev/null +++ b/prf/hmac-sha256-circuits/src/verify_data.rs @@ -0,0 +1,152 @@ +use std::sync::Arc; + +use mpc_circuits::{builder::CircuitBuilder, circuits::nbit_xor, BitOrder, Circuit, ValueType}; +use utils::bits::IterToBits; + +use crate::add_prf; + +/// Computes verify_data as specified in RFC 5246, Section 7.4.9. +/// +/// verify_data +/// PRF(master_secret, finished_label, Hash(handshake_messages))[0..verify_data_length-1]; +/// +/// Inputs: +/// +/// 0. OUTER_STATE: 32-byte MS outer-hash state H(ms ⊕ opad) +/// 1. INNER_STATE: 32-byte MS inner-hash state H(ms ⊕ ipad) +/// 2. HS_HASH: 32-byte handshake hash +/// 3. MASK: 12-byte mask for verify_data +/// +/// Outputs: +/// +/// 0. MASKED_VD: 12-byte masked verify_data (VD + MASK) +pub fn verify_data(label: &[u8]) -> Arc { + let label = label.to_vec(); + + let mut builder = CircuitBuilder::new("verify_data", "", "0.1.0", BitOrder::Msb0); + + let outer_hash_state = builder.add_input( + "OUTER_STATE", + "32-byte MS outer-hash state H(ms ⊕ opad)", + ValueType::Bytes, + 256, + ); + let inner_hash_state = builder.add_input( + "INNER_STATE", + "32-byte MS inner-hash state H(ms ⊕ ipad)", + ValueType::Bytes, + 256, + ); + let hs_hash = builder.add_input("HS_HASH", "32-byte handshake hash", ValueType::Bytes, 256); + let mask = builder.add_input("MASK", "12-byte mask for verify_data", ValueType::Bytes, 96); + let const_zero = builder.add_input( + "const_zero", + "input that is always 0", + ValueType::ConstZero, + 1, + ); + let const_one = builder.add_input( + "const_one", + "input that is always 1", + ValueType::ConstOne, + 1, + ); + + let mut builder = builder.build_inputs(); + + let xor = builder.add_circ(&nbit_xor(96)); + + let label = label + .into_msb0_iter() + .map(|bit| if bit { const_one[0] } else { const_zero[0] }) + .collect::>(); + + let vd = add_prf( + &mut builder, + &outer_hash_state[..], + &inner_hash_state[..], + &const_zero[0], + &const_one[0], + &label, + &hs_hash[..], + 12, + ); + + // Apply mask to vd + let masked_vd = { + builder.connect(&vd, &xor.input(0).expect("nbit_xor missing input 0")[..]); + builder.connect( + &mask[..], + &xor.input(1).expect("nbit_xor missing input 1")[..], + ); + xor.output(0).expect("nbit_xor missing output 0") + }; + + let mut builder = builder.build_gates(); + + let out_masked_vd = builder.add_output( + "MASKED_VD", + "12-byte masked verify_data", + ValueType::Bytes, + 96, + ); + + builder.connect(&masked_vd[..], &out_masked_vd[..]); + + builder + .build_circuit() + .expect("failed to build verify_data") +} + +#[cfg(test)] +mod tests { + use super::*; + use mpc_circuits::{circuits::test_circ, Value}; + + const CF_LABEL: &[u8; 15] = b"client finished"; + + #[test] + #[ignore = "expensive"] + fn test_verify_data() { + let circ = verify_data(CF_LABEL); + + println!("VD Circuit size: {}", circ.and_count()); + + let ms = [254u8; 48]; + let mask = [249u8; 12]; + let hs_hash = [99u8; 32]; + + let (ms_outer_hash_state, ms_inner_hash_state) = hmac_sha256_utils::partial_hmac(&ms); + + let vd = hmac_sha256_utils::prf(&ms, CF_LABEL, &hs_hash, 12); + + let vd_masked = vd + .iter() + .zip(mask.iter()) + .map(|(a, b)| a ^ b) + .collect::>(); + + test_circ( + &circ, + &[ + Value::Bytes( + ms_outer_hash_state + .into_iter() + .map(|v| v.to_be_bytes()) + .flatten() + .collect::>(), + ), + Value::Bytes( + ms_inner_hash_state + .into_iter() + .map(|v| v.to_be_bytes()) + .flatten() + .collect::>(), + ), + Value::Bytes(hs_hash.to_vec()), + Value::Bytes(mask.to_vec()), + ], + &[Value::Bytes(vd_masked)], + ); + } +} diff --git a/prf/hmac-sha256-core/Cargo.toml b/prf/hmac-sha256-core/Cargo.toml new file mode 100644 index 000000000..ebb01d454 --- /dev/null +++ b/prf/hmac-sha256-core/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "tlsn-hmac-sha256-core" +version = "0.1.0" +edition = "2021" + +[lib] +name = "hmac_sha256_core" + +[features] +default = [] +serde = ["dep:serde"] +build-circuits = [] + +[dependencies] +tlsn-mpc-circuits.workspace = true +tlsn-mpc-core.workspace = true +tlsn-utils.workspace = true +tlsn-hmac-sha256-utils = { path = "../hmac-sha256-utils" } + +sha2 = { workspace = true, features = ["compress"] } +digest.workspace = true +hmac.workspace = true +rand.workspace = true +thiserror.workspace = true +serde = { workspace = true, features = ["derive"], optional = true } +once_cell.workspace = true +derive_builder.workspace = true + +[dev-dependencies] +criterion.workspace = true +rand_chacha.workspace = true + +[build-dependencies] +tlsn-mpc-circuits.workspace = true +tlsn-hmac-sha256-circuits = { path = "../hmac-sha256-circuits" } +prost.workspace = true +rayon.workspace = true diff --git a/tls/tls-2pc-core/build.rs b/prf/hmac-sha256-core/build.rs similarity index 78% rename from tls/tls-2pc-core/build.rs rename to prf/hmac-sha256-core/build.rs index 2e0d8a1cf..7a5df3304 100644 --- a/tls/tls-2pc-core/build.rs +++ b/prf/hmac-sha256-core/build.rs @@ -1,19 +1,16 @@ +use hmac_sha256_circuits::{master_secret, session_keys, verify_data}; use mpc_circuits::{proto, Circuit}; use prost::Message; use rayon::prelude::*; use std::{env, fs, io, path::Path, sync::Arc}; -use tls_circuits::{c1, c2, c3, c4, c5, c6, c7}; type CircuitBuilderMap = [(&'static str, fn() -> Arc)]; static CIRCUITS: &CircuitBuilderMap = &[ - ("c1", c1), - ("c2", c2), - ("c3", c3), - ("c4", c4), - ("c5", c5), - ("c6", c6), - ("c7", c7), + ("master_secret", master_secret), + ("session_keys", session_keys), + ("cf_verify_data", || verify_data(b"client finished")), + ("sf_verify_data", || verify_data(b"server finished")), ]; fn build_circuit(name: &str, f: F) -> io::Result<()> @@ -35,9 +32,8 @@ fn build_circuits() -> io::Result<()> { CIRCUITS .into_par_iter() .filter(|(name, _)| { - let enabled = env::var(format!("CARGO_FEATURE_{}", name.to_ascii_uppercase())).is_ok(); let built = Path::new(&format!("circuits/bin/{}.bin", name)).is_file(); - enabled && (!built || force_build) + !built || force_build }) .map(|(name, f)| build_circuit(name, f)) .collect::>() diff --git a/prf/hmac-sha256-core/circuits/bin/.gitignore b/prf/hmac-sha256-core/circuits/bin/.gitignore new file mode 100644 index 000000000..f3ac583d0 --- /dev/null +++ b/prf/hmac-sha256-core/circuits/bin/.gitignore @@ -0,0 +1 @@ +*.bin \ No newline at end of file diff --git a/prf/hmac-sha256-core/src/config.rs b/prf/hmac-sha256-core/src/config.rs new file mode 100644 index 000000000..4d170ee15 --- /dev/null +++ b/prf/hmac-sha256-core/src/config.rs @@ -0,0 +1,39 @@ +use derive_builder::Builder; + +#[derive(Debug, Clone, Builder)] +pub struct PRFLeaderConfig { + id: String, + #[builder(default = "u32::MAX")] + encoder_default_stream_id: u32, +} + +impl PRFLeaderConfig { + /// Returns instance ID. + pub fn id(&self) -> &str { + &self.id + } + + /// Returns default stream ID for encoder. + pub fn encoder_default_stream_id(&self) -> u32 { + self.encoder_default_stream_id + } +} + +#[derive(Debug, Clone, Builder)] +pub struct PRFFollowerConfig { + id: String, + #[builder(default = "u32::MAX")] + encoder_default_stream_id: u32, +} + +impl PRFFollowerConfig { + /// Returns instance ID. + pub fn id(&self) -> &str { + &self.id + } + + /// Returns default stream ID for encoder. + pub fn encoder_default_stream_id(&self) -> u32 { + self.encoder_default_stream_id + } +} diff --git a/prf/hmac-sha256-core/src/lib.rs b/prf/hmac-sha256-core/src/lib.rs new file mode 100644 index 000000000..206062c3d --- /dev/null +++ b/prf/hmac-sha256-core/src/lib.rs @@ -0,0 +1,122 @@ +mod config; +pub mod mock; + +pub use config::{ + PRFFollowerConfig, PRFFollowerConfigBuilder, PRFFollowerConfigBuilderError, PRFLeaderConfig, + PRFLeaderConfigBuilder, PRFLeaderConfigBuilderError, +}; + +use mpc_circuits::Circuit; +use mpc_core::garble::{ActiveLabels, FullLabels}; +use once_cell::sync::Lazy; +use std::sync::Arc; + +/// Master secret +/// +/// Computes the master secret (MS), returning the outer and inner HMAC states. +/// +/// Outer state is H(master_secret ⊕ opad) +/// +/// Inner state is H(master_secret ⊕ ipad) +/// +/// Inputs: +/// +/// 0. PMS: 32-byte pre-master secret +/// 1. CLIENT_RAND: 32-byte client random +/// 2. SERVER_RAND: 32-byte server random +/// +/// Outputs: +/// +/// 0. OUTER_STATE: 32-byte HMAC outer hash state +/// 1. INNER_STATE: 32-byte HMAC inner hash state +pub static MS: Lazy> = Lazy::new(|| { + Circuit::load_bytes(std::include_bytes!("../circuits/bin/master_secret.bin")).unwrap() +}); + +/// Session Keys +/// +/// Compute expanded p1 which consists of client_write_key + server_write_key +/// Compute expanded p2 which consists of client_IV + server_IV +/// +/// Inputs: +/// +/// 0. OUTER_HASH_STATE: 32-byte MS outer-hash state +/// 1. INNER_HASH_STATE: 32-byte MS inner-hash state +/// 2. CLIENT_RAND: 32-byte client random +/// 3. SERVER_RAND: 32-byte server random +/// +/// Outputs: +/// +/// 0. CWK: 16-byte client write-key +/// 1. SWK: 16-byte server write-key +/// 2. CIV: 4-byte client IV +/// 3. SIV: 4-byte server IV +pub static SESSION_KEYS: Lazy> = Lazy::new(|| { + Circuit::load_bytes(std::include_bytes!("../circuits/bin/session_keys.bin")).unwrap() +}); + +/// Computes client finished verify_data as specified in RFC 5246, Section 7.4.9. +/// +/// Inputs: +/// +/// 0. OUTER_STATE: 32-byte MS outer-hash state H(ms ⊕ opad) +/// 1. INNER_STATE: 32-byte MS inner-hash state H(ms ⊕ ipad) +/// 2. HS_HASH: 32-byte handshake hash +/// 3. MASK: 12-byte mask for verify_data +/// +/// Outputs: +/// +/// 0. MASKED_VD: 12-byte masked client finished verify_data (VD + MASK) +pub static CF_VD: Lazy> = Lazy::new(|| { + Circuit::load_bytes(std::include_bytes!("../circuits/bin/cf_verify_data.bin")).unwrap() +}); + +/// Computes server finished verify_data as specified in RFC 5246, Section 7.4.9. +/// +/// Inputs: +/// +/// 0. OUTER_STATE: 32-byte MS outer-hash state H(ms ⊕ opad) +/// 1. INNER_STATE: 32-byte MS inner-hash state H(ms ⊕ ipad) +/// 2. HS_HASH: 32-byte handshake hash +/// 3. MASK: 12-byte mask for verify_data +/// +/// Outputs: +/// +/// 0. MASKED_VD: 12-byte masked server finished verify_data (VD + MASK) +pub static SF_VD: Lazy> = Lazy::new(|| { + Circuit::load_bytes(std::include_bytes!("../circuits/bin/sf_verify_data.bin")).unwrap() +}); + +#[derive(Debug, Clone)] +pub struct PmsLabels { + pub full: FullLabels, + pub active: ActiveLabels, +} + +#[derive(Debug, Clone)] +pub struct MasterSecretStateLabels { + pub full_outer_hash_state: FullLabels, + pub full_inner_hash_state: FullLabels, + pub active_outer_hash_state: ActiveLabels, + pub active_inner_hash_state: ActiveLabels, + pub full_client_random: FullLabels, + pub full_server_random: FullLabels, + pub active_client_random: ActiveLabels, + pub active_server_random: ActiveLabels, + pub full_const_zero: FullLabels, + pub full_const_one: FullLabels, + pub active_const_zero: ActiveLabels, + pub active_const_one: ActiveLabels, +} + +#[derive(Debug, Clone)] +pub struct SessionKeyLabels { + pub full_cwk: FullLabels, + pub full_swk: FullLabels, + pub full_civ: FullLabels, + pub full_siv: FullLabels, + pub active_cwk: ActiveLabels, + pub active_swk: ActiveLabels, + pub active_civ: ActiveLabels, + pub active_siv: ActiveLabels, +} diff --git a/prf/hmac-sha256-core/src/mock.rs b/prf/hmac-sha256-core/src/mock.rs new file mode 100644 index 000000000..08129a383 --- /dev/null +++ b/prf/hmac-sha256-core/src/mock.rs @@ -0,0 +1,164 @@ +use mpc_circuits::{BitOrder, Value}; +use mpc_core::garble::{ChaChaEncoder, Encoder}; + +use super::*; + +pub fn create_mock_pms_labels( + pms: [u8; 32], +) -> ((PmsLabels, PmsLabels), (ChaChaEncoder, ChaChaEncoder)) { + let mut leader_encoder = ChaChaEncoder::new([0u8; 32], BitOrder::Msb0); + let mut follower_encoder = ChaChaEncoder::new([1u8; 32], BitOrder::Msb0); + + let pms = pms.to_vec(); + + let leader_delta = leader_encoder.get_delta(); + let follower_delta = follower_encoder.get_delta(); + + let leader_rng = leader_encoder.get_stream(0); + let follower_rng = follower_encoder.get_stream(0); + + let leader_full_labels = FullLabels::generate(leader_rng, 256, Some(leader_delta)); + let follower_full_labels = FullLabels::generate(follower_rng, 256, Some(follower_delta)); + + let leader_active_labels = leader_full_labels + .select(&pms.clone().into(), BitOrder::Msb0) + .unwrap(); + let follower_active_labels = follower_full_labels + .select(&pms.into(), BitOrder::Msb0) + .unwrap(); + + let leader_pms_labels = PmsLabels { + full: leader_full_labels, + active: follower_active_labels, + }; + + let follower_pms_labels = PmsLabels { + full: follower_full_labels, + active: leader_active_labels, + }; + + ( + (leader_pms_labels, follower_pms_labels), + (leader_encoder, follower_encoder), + ) +} + +pub fn create_mock_ms_state_labels( + ms: [u8; 48], + client_random: [u8; 32], + server_random: [u8; 32], +) -> ( + (MasterSecretStateLabels, MasterSecretStateLabels), + (ChaChaEncoder, ChaChaEncoder), +) { + let mut leader_encoder = ChaChaEncoder::new([0u8; 32], BitOrder::Msb0); + let mut follower_encoder = ChaChaEncoder::new([1u8; 32], BitOrder::Msb0); + + let (outer_hash_state, inner_hash_state) = hmac_sha256_utils::partial_hmac(&ms); + + let outer_hash_state = outer_hash_state + .iter() + .map(|chunk| chunk.to_be_bytes()) + .flatten() + .collect::>(); + let inner_hash_state = inner_hash_state + .iter() + .map(|chunk| chunk.to_be_bytes()) + .flatten() + .collect::>(); + + let leader_delta = leader_encoder.get_delta(); + let follower_delta = follower_encoder.get_delta(); + + let leader_rng = leader_encoder.get_stream(0); + let follower_rng = follower_encoder.get_stream(0); + + let leader_full_outer_hash_state = FullLabels::generate(leader_rng, 256, Some(leader_delta)); + let leader_full_inner_hash_state = FullLabels::generate(leader_rng, 256, Some(leader_delta)); + let leader_full_client_random = FullLabels::generate(leader_rng, 256, Some(leader_delta)); + let leader_full_server_random = FullLabels::generate(leader_rng, 256, Some(leader_delta)); + let leader_full_const_zero = FullLabels::generate(leader_rng, 1, Some(leader_delta)); + let leader_full_const_one = FullLabels::generate(leader_rng, 1, Some(leader_delta)); + + let follower_full_outer_hash_state = + FullLabels::generate(follower_rng, 256, Some(follower_delta)); + let follower_full_inner_hash_state = + FullLabels::generate(follower_rng, 256, Some(follower_delta)); + let follower_full_client_random = FullLabels::generate(follower_rng, 256, Some(follower_delta)); + let follower_full_server_random = FullLabels::generate(follower_rng, 256, Some(follower_delta)); + let follower_full_const_zero = FullLabels::generate(follower_rng, 1, Some(follower_delta)); + let follower_full_const_one = FullLabels::generate(follower_rng, 1, Some(follower_delta)); + + let leader_active_outer_hash_state = follower_full_outer_hash_state + .select(&outer_hash_state.clone().into(), BitOrder::Msb0) + .unwrap(); + let leader_active_inner_hash_state = follower_full_inner_hash_state + .select(&inner_hash_state.clone().into(), BitOrder::Msb0) + .unwrap(); + let leader_active_client_random = follower_full_client_random + .select(&client_random.to_vec().into(), BitOrder::Msb0) + .unwrap(); + let leader_active_server_random = follower_full_server_random + .select(&server_random.to_vec().into(), BitOrder::Msb0) + .unwrap(); + let leader_active_const_zero = follower_full_const_zero + .select(&Value::ConstZero, BitOrder::Msb0) + .unwrap(); + let leader_active_const_one = follower_full_const_one + .select(&Value::ConstOne, BitOrder::Msb0) + .unwrap(); + + let follower_active_outer_hash_state = leader_full_outer_hash_state + .select(&outer_hash_state.into(), BitOrder::Msb0) + .unwrap(); + let follower_active_inner_hash_state = leader_full_inner_hash_state + .select(&inner_hash_state.into(), BitOrder::Msb0) + .unwrap(); + let follower_active_client_random = leader_full_client_random + .select(&client_random.to_vec().into(), BitOrder::Msb0) + .unwrap(); + let follower_active_server_random = leader_full_server_random + .select(&server_random.to_vec().into(), BitOrder::Msb0) + .unwrap(); + let follower_active_const_zero = leader_full_const_zero + .select(&Value::ConstZero, BitOrder::Msb0) + .unwrap(); + let follower_active_const_one = leader_full_const_one + .select(&Value::ConstOne, BitOrder::Msb0) + .unwrap(); + + let leader_ms_state_labels = MasterSecretStateLabels { + full_outer_hash_state: leader_full_outer_hash_state, + full_inner_hash_state: leader_full_inner_hash_state, + active_outer_hash_state: leader_active_outer_hash_state, + active_inner_hash_state: leader_active_inner_hash_state, + full_client_random: leader_full_client_random, + full_server_random: leader_full_server_random, + active_client_random: leader_active_client_random, + active_server_random: leader_active_server_random, + full_const_zero: leader_full_const_zero, + active_const_zero: leader_active_const_zero, + full_const_one: leader_full_const_one, + active_const_one: leader_active_const_one, + }; + + let follower_ms_state_labels = MasterSecretStateLabels { + full_outer_hash_state: follower_full_outer_hash_state, + full_inner_hash_state: follower_full_inner_hash_state, + active_outer_hash_state: follower_active_outer_hash_state, + active_inner_hash_state: follower_active_inner_hash_state, + full_client_random: follower_full_client_random, + full_server_random: follower_full_server_random, + active_client_random: follower_active_client_random, + active_server_random: follower_active_server_random, + full_const_zero: follower_full_const_zero, + active_const_zero: follower_active_const_zero, + full_const_one: follower_full_const_one, + active_const_one: follower_active_const_one, + }; + + ( + (leader_ms_state_labels, follower_ms_state_labels), + (leader_encoder, follower_encoder), + ) +} diff --git a/prf/hmac-sha256-utils/Cargo.toml b/prf/hmac-sha256-utils/Cargo.toml new file mode 100644 index 000000000..7249b63d4 --- /dev/null +++ b/prf/hmac-sha256-utils/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "tlsn-hmac-sha256-utils" +version = "0.1.0" +edition = "2021" + +[lib] +name = "hmac_sha256_utils" + +[dependencies] +hmac.workspace = true +sha2 = { workspace = true, features = ["compress"] } + +[dev-dependencies] +ring = "0.16" diff --git a/prf/hmac-sha256-utils/src/lib.rs b/prf/hmac-sha256-utils/src/lib.rs new file mode 100644 index 000000000..ada7a8619 --- /dev/null +++ b/prf/hmac-sha256-utils/src/lib.rs @@ -0,0 +1,220 @@ +//! Helper functions for HMAC-SHA256 PRF testing. + +use std::slice::from_ref; + +use hmac::{ + digest::{ + block_buffer::{BlockBuffer, Eager}, + typenum::U64, + }, + Hmac, Mac, +}; + +pub static SHA256_INITIAL_STATE: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +pub fn partial_sha256_digest(input: &[u8]) -> [u32; 8] { + let mut state = SHA256_INITIAL_STATE; + for b in input.chunks(64) { + let mut block = [0u8; 64]; + block[..b.len()].copy_from_slice(b); + sha2::compress256(&mut state, &[block.into()]); + } + state +} + +pub fn finalize_sha256_digest(mut state: [u32; 8], pos: usize, input: &[u8]) -> [u8; 32] { + let mut buffer = BlockBuffer::::default(); + buffer.digest_blocks(input, |b| sha2::compress256(&mut state, b)); + buffer.digest_pad( + 0x80, + &(((input.len() + pos) * 8) as u64).to_be_bytes(), + |b| sha2::compress256(&mut state, from_ref(b)), + ); + + let mut out: [u8; 32] = [0; 32]; + for (chunk, v) in out.chunks_exact_mut(4).zip(state.iter()) { + chunk.copy_from_slice(&v.to_be_bytes()); + } + out +} + +pub fn partial_hmac(key: &[u8]) -> ([u32; 8], [u32; 8]) { + let mut key_opad = [0x5cu8; 64]; + let mut key_ipad = [0x36u8; 64]; + + key_opad.iter_mut().zip(key).for_each(|(a, b)| *a ^= b); + key_ipad.iter_mut().zip(key).for_each(|(a, b)| *a ^= b); + + let outer_state = partial_sha256_digest(&key_opad); + let inner_state = partial_sha256_digest(&key_ipad); + + (outer_state, inner_state) +} + +pub fn hmac(key: &[u8], msg: &[u8]) -> Vec { + let mut hmac = Hmac::::new_from_slice(key).unwrap(); + hmac.update(msg); + hmac.finalize().into_bytes().to_vec() +} + +pub fn prf_a(key: &[u8], seed: &[u8], i: usize) -> Vec { + (0..i).fold(seed.to_vec(), |a_prev, _| hmac(key, &a_prev)) +} + +fn prf_p_hash(key: &[u8], seed: &[u8], iterations: usize) -> Vec { + (0..iterations) + .map(|i| { + let msg = { + let mut msg = prf_a(key, seed, i + 1); + msg.extend_from_slice(seed); + msg + }; + hmac(key, &msg) + }) + .flatten() + .collect() +} + +pub fn prf(key: &[u8], label: &[u8], seed: &[u8], bytes: usize) -> Vec { + let iterations = bytes / 32 + (bytes % 32 != 0) as usize; + + let mut label_seed = label.to_vec(); + label_seed.extend_from_slice(seed); + + prf_p_hash(key, &label_seed, iterations)[..bytes].to_vec() +} + +#[cfg(test)] +mod tests { + use super::*; + use hmac::{Hmac, Mac}; + use sha2::{Digest, Sha256}; + + type HmacSha256 = Hmac; + + #[test] + fn test_sha2_initial_state() { + let s = b"test string"; + + // initial state for sha2 + let state = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, + 0x5be0cd19, + ]; + let digest = finalize_sha256_digest(state, 0, s); + + let mut hasher = Sha256::new(); + hasher.update(s); + assert_eq!(digest, hasher.finalize().as_slice()); + } + + #[test] + fn test_sha2_resume_state() { + let s = b"test string test string test string test string test string test"; + + let state = partial_sha256_digest(s); + + let s2 = b"additional data "; + + let digest = finalize_sha256_digest(state, s.len(), s2); + + let mut hasher = Sha256::new(); + hasher.update(s); + hasher.update(s2); + assert_eq!(digest, hasher.finalize().as_slice()); + } + + #[test] + fn test_partial_hmac() { + let key = [42u8; 32]; + let msg = b"test string"; + let (outer_state, inner_state) = partial_hmac(&key); + + let hmac = finalize_sha256_digest( + outer_state, + 64, + finalize_sha256_digest(inner_state, 64, msg).as_slice(), + ); + + let expected_hmac: [u8; 32] = { + let mut hmac = HmacSha256::new_from_slice(&key).unwrap(); + hmac.update(msg); + hmac.finalize().into_bytes().into() + }; + + assert_eq!(hmac, expected_hmac); + } + + #[test] + fn test_hmac() { + let key = [42u8; 32]; + let msg = b"test string"; + + let hmac = hmac(&key, msg); + + let expected_hmac = { + let mut hmac = HmacSha256::new_from_slice(&key).unwrap(); + hmac.update(msg); + hmac.finalize().into_bytes().to_vec() + }; + + assert_eq!(hmac, expected_hmac); + } + + #[test] + fn test_prf() { + let key = [42u8; 32]; + let seed = [69u8; 64]; + let label = b"test label"; + + let output = prf(&key, label, &seed, 32); + + let mut expected_output = [0u8; 32]; + ring_prf::prf(&mut expected_output, &key, label, &seed); + + assert_eq!(output, expected_output); + } + + // Borrowed from Rustls for testing (we don't want ring as a dependency) + // https://github.com/rustls/rustls/blob/main/rustls/src/tls12/prf.rs + mod ring_prf { + use ring::{hmac, hmac::HMAC_SHA256}; + + fn concat_sign(key: &hmac::Key, a: &[u8], b: &[u8]) -> hmac::Tag { + let mut ctx = hmac::Context::with_key(key); + ctx.update(a); + ctx.update(b); + ctx.sign() + } + + fn p(out: &mut [u8], secret: &[u8], seed: &[u8]) { + let hmac_key = hmac::Key::new(HMAC_SHA256, secret); + + // A(1) + let mut current_a = hmac::sign(&hmac_key, seed); + let chunk_size = HMAC_SHA256.digest_algorithm().output_len; + for chunk in out.chunks_mut(chunk_size) { + // P_hash[i] = HMAC_hash(secret, A(i) + seed) + let p_term = concat_sign(&hmac_key, current_a.as_ref(), seed); + chunk.copy_from_slice(&p_term.as_ref()[..chunk.len()]); + + // A(i+1) = HMAC_hash(secret, A(i)) + current_a = hmac::sign(&hmac_key, current_a.as_ref()); + } + } + + fn concat(a: &[u8], b: &[u8]) -> Vec { + let mut ret = Vec::new(); + ret.extend_from_slice(a); + ret.extend_from_slice(b); + ret + } + + pub fn prf(out: &mut [u8], secret: &[u8], label: &[u8], seed: &[u8]) { + let joined_seed = concat(label, seed); + p(out, secret, &joined_seed); + } + } +} diff --git a/prf/hmac-sha256/Cargo.toml b/prf/hmac-sha256/Cargo.toml new file mode 100644 index 000000000..5e86cf9eb --- /dev/null +++ b/prf/hmac-sha256/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "tlsn-hmac-sha256" +version = "0.1.0" +edition = "2021" + +[lib] +name = "hmac_sha256" + +[features] +default = ["mock"] +mock = [] + +[dependencies] +tlsn-hmac-sha256-core = { path = "../hmac-sha256-core" } +tlsn-mpc-circuits.workspace = true +tlsn-mpc-core.workspace = true +tlsn-mpc-aio.workspace = true +tlsn-utils-aio.workspace = true + +rand_chacha.workspace = true +rand.workspace = true + +async-trait.workspace = true +futures.workspace = true + +thiserror.workspace = true +derive_builder.workspace = true + +[dev-dependencies] +tlsn-utils.workspace = true +tlsn-hmac-sha256-utils = { path = "../hmac-sha256-utils" } +criterion = { workspace = true, features = ["async_tokio"] } +tokio.workspace = true + +[[bench]] +name = "mock" +harness = false diff --git a/prf/hmac-sha256/benches/mock.rs b/prf/hmac-sha256/benches/mock.rs new file mode 100644 index 000000000..f2f7032df --- /dev/null +++ b/prf/hmac-sha256/benches/mock.rs @@ -0,0 +1,60 @@ +use std::sync::Arc; + +use futures::lock::Mutex; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +use hmac_sha256::mock; +use hmac_sha256_core::{PRFFollowerConfigBuilder, PRFLeaderConfigBuilder}; + +async fn bench_prf() { + let leader_config = PRFLeaderConfigBuilder::default() + .id("test".to_string()) + .build() + .unwrap(); + let follower_config = PRFFollowerConfigBuilder::default() + .id("test".to_string()) + .build() + .unwrap(); + + let (mut leader, mut follower) = mock::create_mock_prf_pair(leader_config, follower_config); + + let pms = [42u8; 32]; + + let client_random = [0u8; 32]; + let server_random = [1u8; 32]; + + let ((leader_share, follower_share), (leader_encoder, follower_encoder)) = + mock::create_mock_pms_labels(pms); + + leader.set_encoder(Arc::new(Mutex::new(leader_encoder))); + follower.set_encoder(Arc::new(Mutex::new(follower_encoder))); + + futures::join!( + async move { + leader + .compute_session_keys(client_random, server_random, leader_share) + .await + .unwrap(); + _ = leader.compute_client_finished_vd([0u8; 32]).await.unwrap(); + _ = leader.compute_server_finished_vd([0u8; 32]).await.unwrap(); + }, + async move { + follower.compute_session_keys(follower_share).await.unwrap(); + _ = follower.compute_client_finished_vd().await.unwrap(); + _ = follower.compute_server_finished_vd().await.unwrap(); + } + ); +} + +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("prf"); + + group.bench_function("prf", |b| { + b.to_async(tokio::runtime::Runtime::new().unwrap()) + .iter(|| async { black_box(bench_prf().await) }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/prf/hmac-sha256/src/circuits/mod.rs b/prf/hmac-sha256/src/circuits/mod.rs new file mode 100644 index 000000000..cad58e91e --- /dev/null +++ b/prf/hmac-sha256/src/circuits/mod.rs @@ -0,0 +1,7 @@ +mod ms; +mod session_keys; +mod verify_data; + +pub use ms::*; +pub use session_keys::*; +pub use verify_data::*; diff --git a/prf/hmac-sha256/src/circuits/ms.rs b/prf/hmac-sha256/src/circuits/ms.rs new file mode 100644 index 000000000..1290fd92e --- /dev/null +++ b/prf/hmac-sha256/src/circuits/ms.rs @@ -0,0 +1,158 @@ +use std::sync::Arc; + +use futures::lock::Mutex; +use hmac_sha256_core::{MasterSecretStateLabels, PmsLabels, MS}; +use mpc_aio::protocol::garble::{exec::dual::DEExecute, GCError}; +use mpc_circuits::{Value, WireGroup}; +use mpc_core::garble::{ + exec::dual::DESummary, ActiveEncodedInput, ChaChaEncoder, Encoder, FullEncodedInput, + FullInputSet, +}; + +/// Executes master secret circuit as PRFLeader +/// +/// Returns master secret hash state labels +pub async fn leader_ms( + leader: DE, + encoder: Arc>, + encoder_stream_id: u32, + pms_labels: PmsLabels, + client_random: [u8; 32], + server_random: [u8; 32], +) -> Result { + let inputs = MS.inputs(); + + let client_random = inputs[1] + .clone() + .to_value(Value::Bytes(client_random.to_vec())) + .expect("client_random should be 32 bytes"); + let server_random = inputs[2] + .clone() + .to_value(Value::Bytes(server_random.to_vec())) + .expect("server_random should be 32 bytes"); + let const_zero = inputs[3] + .clone() + .to_value(Value::ConstZero) + .expect("const_zero should be 0"); + let const_one = inputs[4] + .clone() + .to_value(Value::ConstOne) + .expect("const_one should be 1"); + + let (gen_labels, cached_labels) = build_labels(encoder, encoder_stream_id, pms_labels).await; + + let summary = leader + .execute_skip_equality_check( + gen_labels, + vec![ + client_random.clone(), + server_random.clone(), + const_zero, + const_one, + ], + vec![], + vec![client_random, server_random], + cached_labels, + ) + .await?; + + let labels = build_ms_labels(summary); + + Ok(labels) +} + +/// Executes master secret circuit as PRFFollower +/// +/// Returns master secret hash state labels +pub async fn follower_ms( + follower: DE, + encoder: Arc>, + encoder_stream_id: u32, + pms_labels: PmsLabels, +) -> Result { + let inputs = MS.inputs(); + + let client_random = inputs[1].clone(); + let server_random = inputs[2].clone(); + let const_zero = inputs[3] + .clone() + .to_value(Value::ConstZero) + .expect("const_zero should be 0"); + let const_one = inputs[4] + .clone() + .to_value(Value::ConstOne) + .expect("const_one should be 1"); + + let (gen_labels, cached_labels) = build_labels(encoder, encoder_stream_id, pms_labels).await; + + let summary = follower + .execute_skip_equality_check( + gen_labels, + vec![const_zero, const_one], + vec![client_random, server_random], + vec![], + cached_labels, + ) + .await?; + + let labels = build_ms_labels(summary); + + Ok(labels) +} + +async fn build_labels( + encoder: Arc>, + encoder_stream_id: u32, + pms_labels: PmsLabels, +) -> (FullInputSet, Vec) { + let [pms, client_random, server_random, const_zero, const_one] = MS.inputs() else { + panic!("MS circuit should have 5 inputs"); + }; + + let mut encoder = encoder.lock().await; + let delta = encoder.get_delta(); + let rng = encoder.get_stream(encoder_stream_id); + + let full_pms = + FullEncodedInput::from_labels(pms.clone(), pms_labels.full).expect("pms should be valid"); + let full_client_random = FullEncodedInput::generate(rng, client_random.clone(), delta); + let full_server_random = FullEncodedInput::generate(rng, server_random.clone(), delta); + let full_const_zero = FullEncodedInput::generate(rng, const_zero.clone(), delta); + let full_const_one = FullEncodedInput::generate(rng, const_one.clone(), delta); + + let gen_labels = FullInputSet::new(vec![ + full_pms, + full_client_random, + full_server_random, + full_const_zero, + full_const_one, + ]) + .expect("Labels should be valid"); + + let pms_labels = ActiveEncodedInput::from_active_labels(pms.clone(), pms_labels.active) + .expect("pms should be 32 bytes"); + + (gen_labels, vec![pms_labels]) +} + +fn build_ms_labels(summary: DESummary) -> MasterSecretStateLabels { + let full_input_labels = summary.get_generator_summary().input_labels(); + let full_output_labels = summary.get_generator_summary().output_labels(); + let active_input_labels = summary.get_evaluator_summary().input_labels(); + let active_output_labels = summary.get_evaluator_summary().output_labels(); + + MasterSecretStateLabels { + full_outer_hash_state: full_output_labels[0].clone().into_labels(), + full_inner_hash_state: full_output_labels[1].clone().into_labels(), + active_outer_hash_state: active_output_labels[0].clone().into_labels(), + active_inner_hash_state: active_output_labels[1].clone().into_labels(), + full_client_random: full_input_labels[1].clone().into_labels(), + full_server_random: full_input_labels[2].clone().into_labels(), + active_client_random: active_input_labels[1].clone().into_labels(), + active_server_random: active_input_labels[2].clone().into_labels(), + full_const_zero: full_input_labels[3].clone().into_labels(), + full_const_one: full_input_labels[4].clone().into_labels(), + active_const_zero: active_input_labels[3].clone().into_labels(), + active_const_one: active_input_labels[4].clone().into_labels(), + } +} diff --git a/prf/hmac-sha256/src/circuits/session_keys.rs b/prf/hmac-sha256/src/circuits/session_keys.rs new file mode 100644 index 000000000..d0663a0d3 --- /dev/null +++ b/prf/hmac-sha256/src/circuits/session_keys.rs @@ -0,0 +1,276 @@ +use hmac_sha256_core::{MasterSecretStateLabels, SessionKeyLabels, SESSION_KEYS}; +use mpc_aio::protocol::garble::{exec::dual::DEExecute, GCError}; +use mpc_core::garble::{ActiveEncodedInput, FullEncodedInput, FullInputSet}; + +/// Executes session_keys as PRFLeader +/// +/// Returns session key shares +pub async fn session_keys( + de: DE, + ms_labels: MasterSecretStateLabels, +) -> Result { + let (gen_labels, cached_labels) = build_labels(ms_labels); + let gen_inputs = vec![]; + let ot_send_inputs = vec![]; + let ot_receive_inputs = vec![]; + + let de_summary = de + .execute_skip_equality_check( + gen_labels, + gen_inputs, + ot_send_inputs, + ot_receive_inputs, + cached_labels, + ) + .await?; + + let full_output_labels = de_summary.get_generator_summary().output_labels(); + let full_cwk = full_output_labels[0].clone().into_labels(); + let full_swk = full_output_labels[1].clone().into_labels(); + let full_civ = full_output_labels[2].clone().into_labels(); + let full_siv = full_output_labels[3].clone().into_labels(); + + let active_output_labels = de_summary.get_evaluator_summary().output_labels(); + let active_cwk = active_output_labels[0].clone().into_labels(); + let active_swk = active_output_labels[1].clone().into_labels(); + let active_civ = active_output_labels[2].clone().into_labels(); + let active_siv = active_output_labels[3].clone().into_labels(); + + Ok(SessionKeyLabels { + full_cwk, + full_swk, + full_civ, + full_siv, + active_cwk, + active_swk, + active_civ, + active_siv, + }) +} + +fn build_labels(ms_labels: MasterSecretStateLabels) -> (FullInputSet, Vec) { + let [outer_hash_state_input, inner_hash_state_input, client_random_input, server_random_input, const_zero, const_one] = SESSION_KEYS.inputs() else { + panic!("session_keys circuit should have 6 inputs"); + }; + + let full_outer_hash_state = FullEncodedInput::from_labels( + outer_hash_state_input.clone(), + ms_labels.full_outer_hash_state, + ) + .expect("outer_hash_state should be valid"); + + let full_inner_hash_state = FullEncodedInput::from_labels( + inner_hash_state_input.clone(), + ms_labels.full_inner_hash_state, + ) + .expect("inner_hash_state should be valid"); + + let full_client_random = + FullEncodedInput::from_labels(client_random_input.clone(), ms_labels.full_client_random) + .expect("client_random should be valid"); + + let full_server_random = + FullEncodedInput::from_labels(server_random_input.clone(), ms_labels.full_server_random) + .expect("server_random should be valid"); + + let full_const_zero = + FullEncodedInput::from_labels(const_zero.clone(), ms_labels.full_const_zero) + .expect("const_zero should be valid"); + + let full_const_one = FullEncodedInput::from_labels(const_one.clone(), ms_labels.full_const_one) + .expect("const_one should be valid"); + + let gen_labels = FullInputSet::new(vec![ + full_outer_hash_state, + full_inner_hash_state, + full_client_random, + full_server_random, + full_const_zero, + full_const_one, + ]) + .expect("Labels should be valid"); + + let active_outer_hash_state = ActiveEncodedInput::from_active_labels( + outer_hash_state_input.clone(), + ms_labels.active_outer_hash_state, + ) + .expect("outer_hash_state should be valid"); + + let active_inner_hash_state = ActiveEncodedInput::from_active_labels( + inner_hash_state_input.clone(), + ms_labels.active_inner_hash_state, + ) + .expect("inner_hash_state should be valid"); + + let active_client_random = ActiveEncodedInput::from_active_labels( + client_random_input.clone(), + ms_labels.active_client_random, + ) + .expect("client_random should be valid"); + + let active_server_random = ActiveEncodedInput::from_active_labels( + server_random_input.clone(), + ms_labels.active_server_random, + ) + .expect("server_random should be valid"); + + let active_const_zero = + ActiveEncodedInput::from_active_labels(const_zero.clone(), ms_labels.active_const_zero) + .expect("const_zero should be valid"); + + let active_const_one = + ActiveEncodedInput::from_active_labels(const_one.clone(), ms_labels.active_const_one) + .expect("const_one should be valid"); + + let cached_labels = vec![ + active_outer_hash_state, + active_inner_hash_state, + active_client_random, + active_server_random, + active_const_zero, + active_const_one, + ]; + + (gen_labels, cached_labels) +} + +#[cfg(test)] +mod tests { + use super::*; + + use hmac_sha256_core::mock::create_mock_ms_state_labels; + use mpc_aio::protocol::garble::exec::dual::mock::mock_dualex_pair; + use mpc_core::garble::exec::dual::DualExConfigBuilder; + use utils::bits::FromBits; + + #[ignore = "expensive"] + #[tokio::test] + async fn test_session_keys() { + let de_config = DualExConfigBuilder::default() + .id("test".to_string()) + .circ(SESSION_KEYS.clone()) + .build() + .expect("DE config should be valid"); + let (gc_leader, gc_follower) = mock_dualex_pair(de_config); + + let ms = [42u8; 48]; + let client_random = [69u8; 32]; + let server_random = [96u8; 32]; + + let ((leader_labels, follower_labels), _) = + create_mock_ms_state_labels(ms, client_random, server_random); + + let (leader_keys, follower_keys) = tokio::try_join!( + session_keys(gc_leader, leader_labels), + session_keys(gc_follower, follower_labels) + ) + .unwrap(); + + let ( + leader_cwk, + leader_swk, + leader_civ, + leader_siv, + follower_cwk, + follower_swk, + follower_civ, + follower_siv, + ) = decode_keys(leader_keys, follower_keys); + + assert_eq!(leader_cwk, follower_cwk); + assert_eq!(leader_swk, follower_swk); + assert_eq!(leader_civ, follower_civ); + assert_eq!(leader_siv, follower_siv); + + let seed = server_random + .iter() + .chain(client_random.iter()) + .copied() + .collect::>(); + let expected_key_material = hmac_sha256_utils::prf(&ms, b"key expansion", &seed, 40); + + let expected_cwk = expected_key_material[0..16].to_vec(); + let expected_swk = expected_key_material[16..32].to_vec(); + let expected_civ = expected_key_material[32..36].to_vec(); + let expected_siv = expected_key_material[36..40].to_vec(); + + assert_eq!(leader_cwk, expected_cwk); + assert_eq!(leader_swk, expected_swk); + assert_eq!(leader_civ, expected_civ); + assert_eq!(leader_siv, expected_siv); + } + + fn decode_keys( + leader_keys: SessionKeyLabels, + follower_keys: SessionKeyLabels, + ) -> ( + Vec, + Vec, + Vec, + Vec, + Vec, + Vec, + Vec, + Vec, + ) { + let leader_cwk = Vec::::from_msb0( + leader_keys + .active_cwk + .decode(follower_keys.full_cwk.get_decoding()) + .unwrap(), + ); + let leader_swk = Vec::::from_msb0( + leader_keys + .active_swk + .decode(follower_keys.full_swk.get_decoding()) + .unwrap(), + ); + let leader_civ = Vec::::from_msb0( + leader_keys + .active_civ + .decode(follower_keys.full_civ.get_decoding()) + .unwrap(), + ); + let leader_siv = Vec::::from_msb0( + leader_keys + .active_siv + .decode(follower_keys.full_siv.get_decoding()) + .unwrap(), + ); + let follower_cwk = Vec::::from_msb0( + follower_keys + .active_cwk + .decode(leader_keys.full_cwk.get_decoding()) + .unwrap(), + ); + let follower_swk = Vec::::from_msb0( + follower_keys + .active_swk + .decode(leader_keys.full_swk.get_decoding()) + .unwrap(), + ); + let follower_civ = Vec::::from_msb0( + follower_keys + .active_civ + .decode(leader_keys.full_civ.get_decoding()) + .unwrap(), + ); + let follower_siv = Vec::::from_msb0( + follower_keys + .active_siv + .decode(leader_keys.full_siv.get_decoding()) + .unwrap(), + ); + + ( + leader_cwk, + leader_swk, + leader_civ, + leader_siv, + follower_cwk, + follower_swk, + follower_civ, + follower_siv, + ) + } +} diff --git a/prf/hmac-sha256/src/circuits/verify_data.rs b/prf/hmac-sha256/src/circuits/verify_data.rs new file mode 100644 index 000000000..e143c5813 --- /dev/null +++ b/prf/hmac-sha256/src/circuits/verify_data.rs @@ -0,0 +1,226 @@ +use std::sync::Arc; + +use futures::lock::Mutex; +use hmac_sha256_core::MasterSecretStateLabels; +use mpc_aio::protocol::garble::{exec::dual::DEExecute, GCError}; +use mpc_circuits::{Circuit, Value, WireGroup}; +use mpc_core::garble::{ + ActiveEncodedInput, ChaChaEncoder, Encoder, FullEncodedInput, FullInputSet, +}; +use rand::Rng; + +/// Executes pms as PRFLeader +/// +/// Returns inner_hash_state +pub async fn leader_verify_data( + leader: DE, + circ: &Circuit, + encoder: Arc>, + encoder_stream_id: u32, + ms_state_labels: MasterSecretStateLabels, + handshake_hash: [u8; 32], +) -> Result<[u8; 12], GCError> { + let [_, _, handshake_hash_input, mask_input, const_zero, const_one] = circ.inputs() else { + panic!("Verify data circuit should have 6 inputs"); + }; + + let handshake_hash_value = handshake_hash_input + .clone() + .to_value(handshake_hash.iter().copied().rev().collect::>()) + .expect("handshake_hash should be 32 bytes"); + let mask: Vec = rand::thread_rng().gen::<[u8; 12]>().to_vec(); + let mask_value = mask_input + .clone() + .to_value(mask.clone()) + .expect("MASK should be 12 bytes"); + let const_zero_value = const_zero + .clone() + .to_value(Value::ConstZero) + .expect("const_zero should be 0"); + let const_one_value = const_one + .clone() + .to_value(Value::ConstOne) + .expect("const_one should be 1"); + + let (gen_labels, cached_labels) = + build_labels(circ, encoder, encoder_stream_id, ms_state_labels).await; + + let output = leader + .execute( + gen_labels, + vec![ + handshake_hash_value.clone(), + mask_value.clone(), + const_zero_value, + const_one_value, + ], + vec![], + vec![handshake_hash_value, mask_value], + cached_labels, + ) + .await?; + + let Value::Bytes(masked_vd) = output[0].value().clone() else { + panic!("verify_data output 0 should be bytes"); + }; + + // Remove mask + let vd = masked_vd + .iter() + .zip(mask.iter()) + .map(|(a, b)| a ^ b) + .collect::>(); + + Ok(vd.try_into().expect("verify_data should be 12 bytes")) +} + +/// Executes verify_data as PRFFollower +pub async fn follower_verify_data( + follower: DE, + circ: &Circuit, + encoder: Arc>, + encoder_stream_id: u32, + ms_state_labels: MasterSecretStateLabels, +) -> Result<(), GCError> { + let [_, _, handshake_hash_input, mask_input, const_zero, const_one] = circ.inputs() else { + panic!("Verify data circuit should have 6 inputs"); + }; + + let const_zero_value = const_zero + .clone() + .to_value(Value::ConstZero) + .expect("const_zero should be 0"); + let const_one_value = const_one + .clone() + .to_value(Value::ConstOne) + .expect("const_one should be 1"); + + let (gen_labels, cached_labels) = + build_labels(circ, encoder, encoder_stream_id, ms_state_labels).await; + + _ = follower + .execute( + gen_labels, + vec![const_zero_value, const_one_value], + vec![handshake_hash_input.clone(), mask_input.clone()], + vec![], + cached_labels, + ) + .await?; + + Ok(()) +} + +async fn build_labels( + circ: &Circuit, + encoder: Arc>, + encoder_stream_id: u32, + ms_state_labels: MasterSecretStateLabels, +) -> (FullInputSet, Vec) { + let [ms_outer_hash_state, ms_inner_hash_state, handshake_hash, mask, const_zero, const_one] = circ.inputs() else { + panic!("Verify data circuit should have 6 inputs"); + }; + + let full_ms_outer_hash_state_labels = FullEncodedInput::from_labels( + ms_outer_hash_state.clone(), + ms_state_labels.full_outer_hash_state.clone(), + ) + .expect("ms_outer_hash_state_labels should be valid"); + + let full_ms_inner_hash_state_labels = FullEncodedInput::from_labels( + ms_inner_hash_state.clone(), + ms_state_labels.full_inner_hash_state.clone(), + ) + .expect("ms_inner_hash_state_labels should be valid"); + + let mut encoder = encoder.lock().await; + let delta = encoder.get_delta(); + let rng = encoder.get_stream(encoder_stream_id); + + let handshake_hash_labels = FullEncodedInput::generate(rng, handshake_hash.clone(), delta); + let mask_labels = FullEncodedInput::generate(rng, mask.clone(), delta); + let const_zero_labels = FullEncodedInput::generate(rng, const_zero.clone(), delta); + let const_one_labels = FullEncodedInput::generate(rng, const_one.clone(), delta); + + let gen_labels = FullInputSet::new(vec![ + full_ms_outer_hash_state_labels, + full_ms_inner_hash_state_labels, + handshake_hash_labels, + mask_labels, + const_zero_labels, + const_one_labels, + ]) + .expect("All labels should be valid"); + + let active_ms_outer_hash_state_labels = ActiveEncodedInput::from_labels( + ms_outer_hash_state.clone(), + ms_state_labels.active_outer_hash_state.clone(), + ) + .expect("ms_outer_hash_state_labels should be valid"); + + let active_ms_inner_hash_state_labels = ActiveEncodedInput::from_labels( + ms_inner_hash_state.clone(), + ms_state_labels.active_inner_hash_state.clone(), + ) + .expect("ms_inner_hash_state_labels should be valid"); + + ( + gen_labels, + vec![ + active_ms_outer_hash_state_labels, + active_ms_inner_hash_state_labels, + ], + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::mock::create_mock_ms_state_labels; + use hmac_sha256_core::CF_VD; + use mpc_aio::protocol::garble::exec::dual::mock::mock_dualex_pair; + use mpc_core::garble::exec::dual::DualExConfigBuilder; + + #[ignore = "expensive"] + #[tokio::test] + async fn test_vd() { + let de_config = DualExConfigBuilder::default() + .id("test".to_string()) + .circ(CF_VD.clone()) + .build() + .expect("DE config should be valid"); + let (gc_leader, gc_follower) = mock_dualex_pair(de_config); + + let ms = [69u8; 48]; + let client_random = [42u8; 32]; + let server_random = [43u8; 32]; + let hs_hash = [99u8; 32]; + + let ((leader_labels, follower_labels), (leader_encoder, follower_encoder)) = + create_mock_ms_state_labels(ms, client_random, server_random); + + let expected_vd = hmac_sha256_utils::prf(&ms, b"client finished", &hs_hash, 12); + + let (vd, _) = tokio::try_join!( + leader_verify_data( + gc_leader, + &CF_VD, + Arc::new(Mutex::new(leader_encoder)), + 0, + leader_labels, + hs_hash + ), + follower_verify_data( + gc_follower, + &CF_VD, + Arc::new(Mutex::new(follower_encoder)), + 0, + follower_labels + ) + ) + .unwrap(); + + assert_eq!(vd.to_vec(), expected_vd); + } +} diff --git a/prf/hmac-sha256/src/follower.rs b/prf/hmac-sha256/src/follower.rs new file mode 100644 index 000000000..c8c3e93b6 --- /dev/null +++ b/prf/hmac-sha256/src/follower.rs @@ -0,0 +1,183 @@ +use std::{marker::PhantomData, sync::Arc}; + +use async_trait::async_trait; +use follower_core::{CF_VD, SF_VD}; +use futures::lock::Mutex; + +use hmac_sha256_core::{ + self as follower_core, PRFFollowerConfig, PmsLabels, SessionKeyLabels, MS, SESSION_KEYS, +}; +use mpc_aio::protocol::garble::{exec::dual::DEExecute, factory::GCFactoryError}; +use mpc_core::garble::{ + exec::dual::{DualExConfig, DualExConfigBuilder}, + ChaChaEncoder, +}; +use utils_aio::factory::AsyncFactory; + +use crate::{circuits, PRFFollow, State}; + +use super::PRFError; + +pub struct PRFFollower +where + DEF: AsyncFactory, + DE: DEExecute + Send, +{ + config: PRFFollowerConfig, + state: State, + + encoder: Option>>, + + de_factory: DEF, + + _de: PhantomData, +} + +impl PRFFollower +where + DEF: AsyncFactory + Send, + DE: DEExecute + Send, +{ + pub fn new(config: PRFFollowerConfig, de_factory: DEF) -> PRFFollower { + PRFFollower { + config, + state: State::SessionKeys, + encoder: None, + de_factory, + _de: PhantomData, + } + } + + pub fn set_encoder(&mut self, encoder: Arc>) { + self.encoder = Some(encoder); + } + + pub async fn compute_session_keys( + &mut self, + pms_labels: PmsLabels, + ) -> Result { + let state = std::mem::replace(&mut self.state, State::Error); + let encoder = self.encoder.clone().unwrap(); + + let State::SessionKeys = state else { + return Err(PRFError::InvalidState(state)); + }; + + // TODO: Set up this stuff concurrently + let id = format!("{}/ms", self.config.id()); + let de_config = DualExConfigBuilder::default() + .id(id.clone()) + .circ(MS.clone()) + .build() + .expect("DualExConfig should be valid"); + let de_ms = self.de_factory.create(id, de_config).await?; + + let id = format!("{}/ke", self.config.id()); + let de_config = DualExConfigBuilder::default() + .id(id.clone()) + .circ(SESSION_KEYS.clone()) + .build() + .expect("DualExConfig should be valid"); + let de_ke = self.de_factory.create(id, de_config).await?; + + let ms_state_labels = circuits::follower_ms( + de_ms, + encoder, + self.config.encoder_default_stream_id(), + pms_labels, + ) + .await?; + + let session_key_labels = circuits::session_keys(de_ke, ms_state_labels.clone()).await?; + + self.state = State::ClientFinished { + ms_hash_state_labels: ms_state_labels, + }; + + Ok(session_key_labels) + } + + pub async fn compute_client_finished_vd(&mut self) -> Result<(), PRFError> { + let state = std::mem::replace(&mut self.state, State::Error); + let encoder = self.encoder.clone().unwrap(); + + let State::ClientFinished { ms_hash_state_labels } = state else { + return Err(PRFError::InvalidState(state)); + }; + + let id = format!("{}/cf", self.config.id()); + let de_config = DualExConfigBuilder::default() + .id(id.clone()) + .circ(CF_VD.clone()) + .build() + .expect("DualExConfig should be valid"); + let de_cf = self.de_factory.create(id, de_config).await?; + + circuits::follower_verify_data( + de_cf, + &CF_VD, + encoder, + self.config.encoder_default_stream_id(), + ms_hash_state_labels.clone(), + ) + .await?; + + self.state = State::ServerFinished { + ms_hash_state_labels, + }; + + Ok(()) + } + + pub async fn compute_server_finished_vd(&mut self) -> Result<(), PRFError> { + let state = std::mem::replace(&mut self.state, State::Error); + let encoder = self.encoder.clone().unwrap(); + + let State::ServerFinished { ms_hash_state_labels } = state else { + return Err(PRFError::InvalidState(state)); + }; + + let id = format!("{}/sf", self.config.id()); + let de_config = DualExConfigBuilder::default() + .id(id.clone()) + .circ(SF_VD.clone()) + .build() + .expect("DualExConfig should be valid"); + let de_sf = self.de_factory.create(id, de_config).await?; + + circuits::follower_verify_data( + de_sf, + &SF_VD, + encoder, + self.config.encoder_default_stream_id(), + ms_hash_state_labels.clone(), + ) + .await?; + + self.state = State::Complete; + + Ok(()) + } +} + +#[async_trait] +impl PRFFollow for PRFFollower +where + DEF: AsyncFactory + Send, + DE: DEExecute + Send, +{ + async fn compute_session_keys( + &mut self, + pms_labels: PmsLabels, + ) -> Result { + self.compute_session_keys(pms_labels).await + } + + async fn compute_client_finished_vd(&mut self) -> Result<(), PRFError> { + self.compute_client_finished_vd().await + } + + async fn compute_server_finished_vd(&mut self) -> Result<(), PRFError> { + self.compute_server_finished_vd().await + } +} diff --git a/prf/hmac-sha256/src/leader.rs b/prf/hmac-sha256/src/leader.rs new file mode 100644 index 000000000..280f423ad --- /dev/null +++ b/prf/hmac-sha256/src/leader.rs @@ -0,0 +1,208 @@ +use std::{marker::PhantomData, sync::Arc}; + +use async_trait::async_trait; +use futures::lock::Mutex; + +use hmac_sha256_core::{ + self as leader_core, PRFLeaderConfig, PmsLabels, SessionKeyLabels, MS, SESSION_KEYS, +}; +use leader_core::{CF_VD, SF_VD}; +use mpc_aio::protocol::garble::{exec::dual::DEExecute, factory::GCFactoryError}; +use mpc_core::garble::{ + exec::dual::{DualExConfig, DualExConfigBuilder}, + ChaChaEncoder, +}; +use utils_aio::factory::AsyncFactory; + +use crate::{circuits, PRFLead, State}; + +use super::PRFError; + +pub struct PRFLeader +where + DEF: AsyncFactory, + DE: DEExecute + Send, +{ + config: PRFLeaderConfig, + state: State, + + encoder: Option>>, + + de_factory: DEF, + + _de: PhantomData, +} + +impl PRFLeader +where + DEF: AsyncFactory + Send, + DE: DEExecute + Send, +{ + pub fn new(config: PRFLeaderConfig, de_factory: DEF) -> PRFLeader { + PRFLeader { + config, + state: State::SessionKeys, + encoder: None, + de_factory, + _de: PhantomData, + } + } + + pub fn set_encoder(&mut self, encoder: Arc>) { + self.encoder = Some(encoder); + } + + /// Computes leader's shares of the TLS session keys using the session randoms and their + /// share of the PMS + /// + /// Returns session key shares + pub async fn compute_session_keys( + &mut self, + client_random: [u8; 32], + server_random: [u8; 32], + pms_labels: PmsLabels, + ) -> Result { + let state = std::mem::replace(&mut self.state, State::Error); + let encoder = self.encoder.clone().ok_or(PRFError::EncoderNotSet)?; + + let State::SessionKeys = state else { + return Err(PRFError::InvalidState(state)); + }; + + // TODO: Set up this stuff concurrently + let id = format!("{}/ms", self.config.id()); + let de_config = DualExConfigBuilder::default() + .id(id.clone()) + .circ(MS.clone()) + .build() + .expect("DualExConfig should be valid"); + let de_ms = self.de_factory.create(id, de_config).await?; + + let id = format!("{}/ke", self.config.id()); + let de_config = DualExConfigBuilder::default() + .id(id.clone()) + .circ(SESSION_KEYS.clone()) + .build() + .expect("DualExConfig should be valid"); + let de_ke = self.de_factory.create(id, de_config).await?; + + let ms_state_labels = circuits::leader_ms( + de_ms, + encoder, + self.config.encoder_default_stream_id(), + pms_labels, + client_random, + server_random, + ) + .await?; + + let session_key_labels = circuits::session_keys(de_ke, ms_state_labels.clone()).await?; + + self.state = State::ClientFinished { + ms_hash_state_labels: ms_state_labels, + }; + + Ok(session_key_labels) + } + + pub async fn compute_client_finished_vd( + &mut self, + handshake_hash: [u8; 32], + ) -> Result<[u8; 12], PRFError> { + let state = std::mem::replace(&mut self.state, State::Error); + let encoder = self.encoder.clone().ok_or(PRFError::EncoderNotSet)?; + + let State::ClientFinished { ms_hash_state_labels } = state else { + return Err(PRFError::InvalidState(state)); + }; + + let id = format!("{}/cf", self.config.id()); + let de_config = DualExConfigBuilder::default() + .id(id.clone()) + .circ(CF_VD.clone()) + .build() + .expect("DualExConfig should be valid"); + let de_cf = self.de_factory.create(id, de_config).await?; + + let vd = circuits::leader_verify_data( + de_cf, + &CF_VD, + encoder, + self.config.encoder_default_stream_id(), + ms_hash_state_labels.clone(), + handshake_hash, + ) + .await?; + + self.state = State::ServerFinished { + ms_hash_state_labels, + }; + + Ok(vd) + } + + pub async fn compute_server_finished_vd( + &mut self, + handshake_hash: [u8; 32], + ) -> Result<[u8; 12], PRFError> { + let state = std::mem::replace(&mut self.state, State::Error); + let encoder = self.encoder.clone().ok_or(PRFError::EncoderNotSet)?; + + let State::ServerFinished { ms_hash_state_labels } = state else { + return Err(PRFError::InvalidState(state)); + }; + + let id = format!("{}/sf", self.config.id()); + let de_config = DualExConfigBuilder::default() + .id(id.clone()) + .circ(SF_VD.clone()) + .build() + .expect("DualExConfig should be valid"); + let de_sf = self.de_factory.create(id, de_config).await?; + + let vd = circuits::leader_verify_data( + de_sf, + &SF_VD, + encoder, + self.config.encoder_default_stream_id(), + ms_hash_state_labels.clone(), + handshake_hash, + ) + .await?; + + self.state = State::Complete; + + Ok(vd) + } +} + +#[async_trait] +impl PRFLead for PRFLeader +where + DEF: AsyncFactory + Send, + DE: DEExecute + Send, +{ + async fn compute_session_keys( + &mut self, + client_random: [u8; 32], + server_random: [u8; 32], + pms_labels: PmsLabels, + ) -> Result { + self.compute_session_keys(client_random, server_random, pms_labels) + .await + } + + async fn compute_client_finished_vd( + &mut self, + handshake_hash: [u8; 32], + ) -> Result<[u8; 12], PRFError> { + self.compute_client_finished_vd(handshake_hash).await + } + + async fn compute_server_finished_vd( + &mut self, + handshake_hash: [u8; 32], + ) -> Result<[u8; 12], PRFError> { + self.compute_server_finished_vd(handshake_hash).await + } +} diff --git a/prf/hmac-sha256/src/lib.rs b/prf/hmac-sha256/src/lib.rs new file mode 100644 index 000000000..0dc39130b --- /dev/null +++ b/prf/hmac-sha256/src/lib.rs @@ -0,0 +1,218 @@ +//! This module contains the protocol for computing TLS SHA-256 HMAC PRF using 2PC in such a way +//! that neither party learns the session keys, rather, they learn encodings of the keys which can +//! be used in subsequent computations. + +pub(crate) mod circuits; +mod follower; +mod leader; + +use async_trait::async_trait; + +use hmac_sha256_core::{MasterSecretStateLabels, SessionKeyLabels}; +use mpc_aio::protocol::garble::GCError; + +pub use follower::PRFFollower; +pub use leader::PRFLeader; + +pub use hmac_sha256_core::{ + PRFFollowerConfig, PRFFollowerConfigBuilder, PRFFollowerConfigBuilderError, PRFLeaderConfig, + PRFLeaderConfigBuilder, PRFLeaderConfigBuilderError, PmsLabels, +}; + +#[derive(Debug, Clone)] +pub enum State { + SessionKeys, + ClientFinished { + ms_hash_state_labels: MasterSecretStateLabels, + }, + ServerFinished { + ms_hash_state_labels: MasterSecretStateLabels, + }, + Complete, + Error, +} + +#[derive(Debug, thiserror::Error)] +pub enum PRFError { + #[error("GCError: {0}")] + GCError(#[from] GCError), + #[error("GCFactoryError: {0}")] + GCFactoryError(#[from] mpc_aio::protocol::garble::factory::GCFactoryError), + #[error("IO Error: {0}")] + IOError(#[from] std::io::Error), + #[error("MuxerError: {0}")] + MuxerError(#[from] utils_aio::mux::MuxerError), + #[error("Encoder not set")] + EncoderNotSet, + #[error("Invalid state: {0:?}")] + InvalidState(State), +} + +#[async_trait] +pub trait PRFLead { + async fn compute_session_keys( + &mut self, + client_random: [u8; 32], + server_random: [u8; 32], + pms_labels: PmsLabels, + ) -> Result; + + async fn compute_client_finished_vd( + &mut self, + handshake_hash: [u8; 32], + ) -> Result<[u8; 12], PRFError>; + + async fn compute_server_finished_vd( + &mut self, + handshake_hash: [u8; 32], + ) -> Result<[u8; 12], PRFError>; +} + +#[async_trait] +pub trait PRFFollow { + async fn compute_session_keys( + &mut self, + pms_labels: PmsLabels, + ) -> Result; + + async fn compute_client_finished_vd(&mut self) -> Result<(), PRFError>; + + async fn compute_server_finished_vd(&mut self) -> Result<(), PRFError>; +} + +pub mod mock { + use hmac_sha256_core::{PRFFollowerConfig, PRFLeaderConfig}; + use mpc_aio::protocol::garble::{ + exec::dual::mock::{MockDualExFollower, MockDualExLeader}, + factory::dual::mock::{create_mock_dualex_factory, MockDualExFactory}, + }; + + pub use hmac_sha256_core::mock::*; + + use crate::{PRFFollower, PRFLeader}; + + pub fn create_mock_prf_pair( + leader_config: PRFLeaderConfig, + follower_config: PRFFollowerConfig, + ) -> ( + PRFLeader, + PRFFollower, + ) { + let de_factory = create_mock_dualex_factory(); + + let leader = PRFLeader::::new( + leader_config, + de_factory.clone(), + ); + let follower = + PRFFollower::::new(follower_config, de_factory); + + (leader, follower) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use futures::lock::Mutex; + use hmac_sha256_core::{PRFFollowerConfigBuilder, PRFLeaderConfigBuilder}; + + use super::*; + use mock::*; + + #[ignore = "expensive"] + #[tokio::test] + async fn test_prf() { + let leader_config = PRFLeaderConfigBuilder::default() + .id("test".to_string()) + .build() + .unwrap(); + let follower_config = PRFFollowerConfigBuilder::default() + .id("test".to_string()) + .build() + .unwrap(); + + let (mut leader, mut follower) = create_mock_prf_pair(leader_config, follower_config); + + let pms = [42u8; 32]; + let client_random = [69u8; 32]; + let server_random: [u8; 32] = [96u8; 32]; + let cf_hs_hash: [u8; 32] = [1u8; 32]; + let sf_hs_hash: [u8; 32] = [2u8; 32]; + let seed = client_random + .iter() + .chain(&server_random) + .copied() + .collect::>(); + let ms = hmac_sha256_utils::prf(&pms, b"master secret", &seed, 48); + + let ((leader_share, follower_share), (leader_encoder, follower_encoder)) = + create_mock_pms_labels(pms); + + leader.set_encoder(Arc::new(Mutex::new(leader_encoder))); + follower.set_encoder(Arc::new(Mutex::new(follower_encoder))); + + let (leader_keys, follower_keys) = tokio::try_join!( + leader.compute_session_keys(client_random, server_random, leader_share), + follower.compute_session_keys(follower_share) + ) + .unwrap(); + + let leader_cwk = leader_keys + .active_cwk + .decode(follower_keys.full_cwk.get_decoding()) + .unwrap(); + let leader_swk = leader_keys + .active_swk + .decode(follower_keys.full_swk.get_decoding()) + .unwrap(); + let leader_civ = leader_keys + .active_civ + .decode(follower_keys.full_civ.get_decoding()) + .unwrap(); + let leader_siv = leader_keys + .active_siv + .decode(follower_keys.full_siv.get_decoding()) + .unwrap(); + let follower_cwk = follower_keys + .active_cwk + .decode(leader_keys.full_cwk.get_decoding()) + .unwrap(); + let follower_swk = follower_keys + .active_swk + .decode(leader_keys.full_swk.get_decoding()) + .unwrap(); + let follower_civ = follower_keys + .active_civ + .decode(leader_keys.full_civ.get_decoding()) + .unwrap(); + let follower_siv = follower_keys + .active_siv + .decode(leader_keys.full_siv.get_decoding()) + .unwrap(); + + assert_eq!(leader_cwk, follower_cwk); + assert_eq!(leader_swk, follower_swk); + assert_eq!(leader_civ, follower_civ); + assert_eq!(leader_siv, follower_siv); + + let (leader_cf_vd, _) = tokio::try_join!( + leader.compute_client_finished_vd(cf_hs_hash), + follower.compute_client_finished_vd() + ) + .unwrap(); + + let expected_cf_vd = hmac_sha256_utils::prf(&ms, b"client finished", &cf_hs_hash, 12); + assert_eq!(leader_cf_vd.to_vec(), expected_cf_vd); + + let (leader_sf_vd, _) = tokio::try_join!( + leader.compute_server_finished_vd(sf_hs_hash), + follower.compute_server_finished_vd() + ) + .unwrap(); + + let expected_sf_vd = hmac_sha256_utils::prf(&ms, b"server finished", &sf_hs_hash, 12); + assert_eq!(leader_sf_vd.to_vec(), expected_sf_vd); + } +} diff --git a/tls/Cargo.toml b/tls/Cargo.toml index 734c3c170..20fd553c5 100644 --- a/tls/Cargo.toml +++ b/tls/Cargo.toml @@ -1,6 +1,6 @@ [workspace] members = ["tls-core"] -exclude = ["tls-client", "tls-circuits", "tls-2pc-aio", "tls-2pc-core"] +exclude = ["tls-client"] [workspace.dependencies] # tlsn diff --git a/tls/tls-2pc-aio/Cargo.toml b/tls/tls-2pc-aio/Cargo.toml deleted file mode 100644 index f7be6202b..000000000 --- a/tls/tls-2pc-aio/Cargo.toml +++ /dev/null @@ -1,27 +0,0 @@ -[package] -name = "tlsn-tls-2pc-aio" -version = "0.1.0" -edition = "2021" - -[lib] -name = "tls_2pc_aio" - -[features] -default = ["mock"] -mock = ["dep:share-conversion-core", "dep:rand_chacha", "dep:rand"] - -[dependencies] -thiserror.workspace = true -tlsn-tls-2pc-core = { path = "../tls-2pc-core" } -share-conversion-aio = { path = "../../mpc/share-conversion-aio" } -share-conversion-core = { path = "../../mpc/share-conversion-core", optional = true } -rand_chacha = { workspace = true, optional = true } -rand = { workspace = true, optional = true } - -[dev-dependencies] -tokio.workspace = true -ghash_rc.workspace = true - -# [[bench]] -# name = "prf" -# harness = false diff --git a/tls/tls-2pc-aio/benches/prf.rs b/tls/tls-2pc-aio/benches/prf.rs deleted file mode 100644 index 5527e9787..000000000 --- a/tls/tls-2pc-aio/benches/prf.rs +++ /dev/null @@ -1,43 +0,0 @@ -// use criterion::{black_box, criterion_group, criterion_main, Criterion}; -// use mpc_aio::protocol::{garble::exec::dual::mock_dualex_pair, point_addition::P256SecretShare}; -// use tls_2pc_aio::prf::{PRFFollower, PRFLeader, PRFMessage}; -// use tokio::runtime::Runtime; -// use utils_aio::duplex::DuplexChannel; -// -// async fn run_prf() { -// let (leader_channel, follower_channel) = DuplexChannel::::new(); -// let (gc_leader, gc_follower) = mock_dualex_pair(); -// let leader = PRFLeader::new(Box::new(leader_channel), gc_leader); -// let follower = PRFFollower::new(Box::new(follower_channel), gc_follower); -// -// let (task_leader, task_follower) = tokio::join!( -// tokio::task::spawn_blocking(move || { -// futures::executor::block_on(leader.compute_session_keys( -// [0u8; 32], -// [0u8; 32], -// P256SecretShare::new([0u8; 32]), -// )) -// }), -// tokio::task::spawn_blocking(move || { -// futures::executor::block_on( -// follower.compute_session_keys(P256SecretShare::new([0u8; 32])), -// ) -// }) -// ); -// -// let leader_keys = task_leader.unwrap().unwrap(); -// let follower_keys = task_follower.unwrap().unwrap(); -// -// black_box((leader_keys, follower_keys)); -// } -// -// fn criterion_benchmark(c: &mut Criterion) { -// let mut group = c.benchmark_group("prf"); -// -// group.bench_function("run_prf", |b| { -// b.to_async(Runtime::new().unwrap()).iter(|| run_prf()); -// }); -// } -// -// criterion_group!(benches, criterion_benchmark); -// criterion_main!(benches); diff --git a/tls/tls-2pc-aio/src/conn/builder.rs b/tls/tls-2pc-aio/src/conn/builder.rs deleted file mode 100644 index 6867db337..000000000 --- a/tls/tls-2pc-aio/src/conn/builder.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::sync::Arc; - -use tls_client::ClientConfig; - -pub struct WantsRole; -pub struct WantsClientConfig {} - -/// ConnectionMaster configuration -pub struct MasterConfig { - pub client: Arc, - pub probe_server: bool, -} - -impl MasterConfig { - pub fn builder() -> ConfigBuilder { - ConfigBuilder::master() - } -} - -pub struct ConfigBuilder { - state: T, -} - -impl ConfigBuilder { - pub fn master() -> ConfigBuilder { - ConfigBuilder { - state: WantsClientConfig {}, - } - } -} - -impl ConfigBuilder { - pub fn client_config(self, config: Arc) -> MasterConfig { - MasterConfig { - client: config, - probe_server: true, - } - } -} diff --git a/tls/tls-2pc-aio/src/conn/master.rs b/tls/tls-2pc-aio/src/conn/master.rs deleted file mode 100644 index 71121ad4a..000000000 --- a/tls/tls-2pc-aio/src/conn/master.rs +++ /dev/null @@ -1,46 +0,0 @@ -use futures::{AsyncRead, AsyncWrite}; -use std::sync::Arc; - -use tls_client::{ClientConnection, ServerName}; - -use super::builder::MasterConfig as Config; -use crate::Error; - -pub enum State { - Initialized, -} - -pub struct ConnectionMaster { - state: State, - client: ClientConnection, - slave_conn: S, -} - -impl ConnectionMaster -where - S: AsyncWrite + AsyncRead, -{ - pub fn new(config: Arc, server_name: ServerName, slave_conn: S) -> Result { - Ok(Self { - state: State::Initialized, - client: ClientConnection::new(config.client.clone(), server_name)?, - slave_conn, - }) - } - - /// Setup all possible 2PC protocols prior to connecting to Server. - /// Probes Server for supported ciphersuites if configured. - pub async fn setup(&mut self) -> Result<(), Error> { - todo!() - } - - /// Runs TLS handshake with Server to completion - pub async fn complete_handshake(&mut self) -> Result<(), Error> { - todo!() - } - - /// Sends application payload to Server - pub async fn send(&mut self, _payload: &[u8]) -> Result<(), Error> { - todo!() - } -} diff --git a/tls/tls-2pc-aio/src/conn/mod.rs b/tls/tls-2pc-aio/src/conn/mod.rs deleted file mode 100644 index d3b743ee0..000000000 --- a/tls/tls-2pc-aio/src/conn/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -mod builder; -mod master; -mod slave; - -pub use master::ConnectionMaster; -pub use slave::ConnectionSlave; diff --git a/tls/tls-2pc-aio/src/conn/slave.rs b/tls/tls-2pc-aio/src/conn/slave.rs deleted file mode 100644 index 4ae0256fb..000000000 --- a/tls/tls-2pc-aio/src/conn/slave.rs +++ /dev/null @@ -1,14 +0,0 @@ -use futures::{AsyncRead, AsyncWrite}; - -pub struct ConnectionSlave { - master_conn: S, -} - -impl ConnectionSlave -where - S: AsyncWrite + AsyncRead, -{ - pub fn new() -> Self { - todo!() - } -} diff --git a/tls/tls-2pc-aio/src/crypto/master.rs b/tls/tls-2pc-aio/src/crypto/master.rs deleted file mode 100644 index 10b3eb3b3..000000000 --- a/tls/tls-2pc-aio/src/crypto/master.rs +++ /dev/null @@ -1,82 +0,0 @@ -use async_trait::async_trait; -use futures::{AsyncRead, AsyncWrite}; -use tls_client::{Crypto, DecryptMode, EncryptMode, Error, ProtocolVersion, SupportedCipherSuite}; -use tls_core::{ - key::PublicKey, - msgs::{ - handshake::Random, - message::{OpaqueMessage, PlainMessage}, - }, -}; - -/// CryptoMaster implements the TLS Crypto trait using 2PC protocols. -pub struct CryptoMaster { - /// Stream connection to [`CryptoSlave`] - stream: S, -} - -impl CryptoMaster -where - S: AsyncWrite + AsyncRead + Send, -{ - pub fn new(stream: S) -> Self { - Self { stream } - } - - /// Perform setup for 2PC sub-protocols - pub async fn setup(&mut self) { - todo!() - } -} - -#[async_trait] -impl Crypto for CryptoMaster -where - S: AsyncWrite + AsyncRead + Send, -{ - fn select_protocol_version(&mut self, _version: ProtocolVersion) -> Result<(), Error> { - todo!() - } - fn select_cipher_suite(&mut self, _suite: SupportedCipherSuite) -> Result<(), Error> { - todo!() - } - fn suite(&self) -> Result { - todo!() - } - fn set_encrypt(&mut self, _mode: EncryptMode) -> Result<(), Error> { - todo!() - } - fn set_decrypt(&mut self, _mode: DecryptMode) -> Result<(), Error> { - todo!() - } - async fn client_random(&mut self) -> Result { - todo!() - } - async fn client_key_share(&mut self) -> Result { - todo!() - } - async fn set_server_random(&mut self, _random: Random) -> Result<(), Error> { - todo!() - } - async fn set_server_key_share(&mut self, _key: PublicKey) -> Result<(), Error> { - todo!() - } - async fn set_hs_hash_client_key_exchange(&mut self, _hash: &[u8]) -> Result<(), Error> { - todo!() - } - async fn set_hs_hash_server_hello(&mut self, _hash: &[u8]) -> Result<(), Error> { - todo!() - } - async fn server_finished(&mut self, _hash: &[u8]) -> Result, Error> { - todo!() - } - async fn client_finished(&mut self, _hash: &[u8]) -> Result, Error> { - todo!() - } - async fn encrypt(&mut self, _m: PlainMessage, _seq: u64) -> Result { - todo!() - } - async fn decrypt(&mut self, _m: OpaqueMessage, _seq: u64) -> Result { - todo!() - } -} diff --git a/tls/tls-2pc-aio/src/crypto/mod.rs b/tls/tls-2pc-aio/src/crypto/mod.rs deleted file mode 100644 index cb8ea50cf..000000000 --- a/tls/tls-2pc-aio/src/crypto/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -mod master; -mod slave; diff --git a/tls/tls-2pc-aio/src/crypto/slave.rs b/tls/tls-2pc-aio/src/crypto/slave.rs deleted file mode 100644 index 87f09145c..000000000 --- a/tls/tls-2pc-aio/src/crypto/slave.rs +++ /dev/null @@ -1,26 +0,0 @@ -use futures::{AsyncRead, AsyncWrite}; - -/// HandshakeSlave communicates with [`HandshakeMaster`] over a stream to execute TLS operations in 2PC -pub struct HandshakeSlave { - /// Stream connection to [`HandshakeSlave`] - stream: S, -} - -impl HandshakeSlave -where - S: AsyncWrite + AsyncRead + Send, -{ - pub fn new(stream: S) -> Self { - Self { stream } - } - - /// Perform setup for 2PC sub-protocols - pub async fn setup(&mut self) { - todo!() - } - - /// Receives and processes messages from Master over stream - pub async fn run(&mut self) { - todo!() - } -} diff --git a/tls/tls-2pc-aio/src/error.rs b/tls/tls-2pc-aio/src/error.rs deleted file mode 100644 index 6dc53f11d..000000000 --- a/tls/tls-2pc-aio/src/error.rs +++ /dev/null @@ -1,9 +0,0 @@ -#[derive(thiserror::Error, Debug, Clone, PartialEq)] -pub enum Error { - #[error("Encountered error with ClientConnection: {0:?}")] - ClientConnectionError(#[from] tls_client::Error), - #[error("Encountered error during encryption")] - EncryptError, - #[error("Encountered error during decryption")] - DecryptError, -} diff --git a/tls/tls-2pc-aio/src/lib.rs b/tls/tls-2pc-aio/src/lib.rs deleted file mode 100644 index 0c6a0b4e8..000000000 --- a/tls/tls-2pc-aio/src/lib.rs +++ /dev/null @@ -1,4 +0,0 @@ -// pub mod conn; -// pub mod crypto; -// pub mod error; -// pub mod prf; diff --git a/tls/tls-2pc-aio/src/prf/circuits/c1.rs b/tls/tls-2pc-aio/src/prf/circuits/c1.rs deleted file mode 100644 index 9c2cdfeab..000000000 --- a/tls/tls-2pc-aio/src/prf/circuits/c1.rs +++ /dev/null @@ -1,154 +0,0 @@ -use mpc_aio::protocol::{ - garble::{Execute, GCError}, - point_addition::P256SecretShare, -}; -use rand::{thread_rng, Rng}; -use tls_2pc_core::CIRCUIT_1; - -/// Executes c1 as PRFLeader -/// -/// Returns inner_hash_state -pub async fn leader_c1( - exec: &mut T, - secret_share: P256SecretShare, -) -> Result<[u32; 8], GCError> { - let circ = CIRCUIT_1.clone(); - - let input_pms_share = circ.input_value( - 0, - secret_share - .as_bytes() - .iter() - // convert to little-endian - .rev() - .copied() - .collect::>(), - )?; - - let mask: Vec = thread_rng().gen::<[u8; 32]>().to_vec(); - let input_mask = circ.input_value(2, mask.clone())?; - - let inputs = vec![input_pms_share, input_mask]; - let out = exec.execute(circ, &inputs).await?.decode()?; - - // todo make this less gross - let masked_inner_hash_state = if let mpc_circuits::Value::Bytes(v) = - out.get(0).expect("Circuit 1 should have output 0").value() - { - v - } else { - panic!("Circuit 1 output 0 should be 32 bytes") - }; - - // remove XOR mask and convert to big-endian - let inner_hash_state = masked_inner_hash_state - .iter() - .zip(mask.iter()) - .map(|(v, m)| v ^ m) - .collect::>(); - - let inner_hash_state: [u32; 8] = inner_hash_state - .chunks_exact(4) - .map(|chunk| u32::from_be_bytes([chunk[3], chunk[2], chunk[1], chunk[0]])) - .rev() - .collect::>() - .try_into() - .expect("Circuit 1 output 0 should be 32 bytes"); - - Ok(inner_hash_state) -} - -/// Executes c1 as PRFFollower -/// -/// Returns outer_hash_state -pub async fn follower_c1( - exec: &mut T, - secret_share: P256SecretShare, -) -> Result<[u32; 8], GCError> { - let circ = CIRCUIT_1.clone(); - - let input_pms_share = circ.input_value( - 1, - secret_share - .as_bytes() - .iter() - // convert to little-endian - .rev() - .copied() - .collect::>(), - )?; - - let mask: Vec = thread_rng().gen::<[u8; 32]>().to_vec(); - let input_mask = circ.input_value(3, mask.clone())?; - - let inputs = vec![input_pms_share, input_mask]; - let out = exec.execute(circ, &inputs).await?.decode()?; - - // todo make this less gross - let masked_outer_hash_state = if let mpc_circuits::Value::Bytes(v) = - out.get(1).expect("Circuit 1 should have output 1").value() - { - v - } else { - panic!("Circuit 1 output 1 should be 32 bytes") - }; - - // remove XOR mask and convert to big-endian - let outer_hash_state = masked_outer_hash_state - .iter() - .zip(mask.iter()) - .map(|(v, m)| v ^ m) - .collect::>(); - - let outer_hash_state: [u32; 8] = outer_hash_state - .chunks_exact(4) - .map(|chunk| u32::from_be_bytes([chunk[3], chunk[2], chunk[1], chunk[0]])) - .rev() - .collect::>() - .try_into() - .expect("Circuit 1 output 1 should be 32 bytes"); - - Ok(outer_hash_state) -} - -#[cfg(test)] -mod tests { - use mpc_aio::protocol::garble::exec::dual::mock_dualex_pair; - use tls_2pc_core::prf::sha::partial_sha256_digest; - - use super::*; - - #[ignore = "expensive"] - #[tokio::test] - async fn test_c1() { - let (mut gc_leader, mut gc_follower) = mock_dualex_pair(); - let leader_share = P256SecretShare::new([ - 95, 183, 78, 37, 133, 230, 30, 137, 239, 195, 160, 166, 154, 80, 143, 115, 38, 92, 34, - 169, 61, 96, 130, 40, 42, 129, 231, 68, 109, 244, 150, 193, - ]); - let follower_share = P256SecretShare::new([ - 141, 150, 106, 174, 105, 9, 169, 73, 234, 17, 111, 54, 214, 28, 160, 159, 148, 130, - 223, 55, 134, 50, 172, 164, 63, 158, 46, 149, 197, 226, 90, 29, - ]); - let pms = leader_share + follower_share; - let mut pms_zeropadded = [0u8; 64]; - pms_zeropadded[..32].copy_from_slice(&pms); - let pms_ipad = pms_zeropadded.iter().map(|b| b ^ 0x36).collect::>(); - let pms_opad = pms_zeropadded.iter().map(|b| b ^ 0x5c).collect::>(); - let expected_inner_hash_state = partial_sha256_digest(&pms_ipad); - let expected_outer_hash_state = partial_sha256_digest(&pms_opad); - - let (task_leader, task_follower) = tokio::join!( - tokio::spawn(async move { leader_c1(&mut gc_leader, leader_share).await.unwrap() }), - tokio::spawn( - async move { follower_c1(&mut gc_follower, follower_share).await.unwrap() } - ) - ); - - let inner_hash_state = task_leader.unwrap(); - let outer_hash_state = task_follower.unwrap(); - - assert_eq!(inner_hash_state, expected_inner_hash_state); - assert_eq!(outer_hash_state, expected_outer_hash_state); - } -} diff --git a/tls/tls-2pc-aio/src/prf/circuits/c2.rs b/tls/tls-2pc-aio/src/prf/circuits/c2.rs deleted file mode 100644 index 44c46e673..000000000 --- a/tls/tls-2pc-aio/src/prf/circuits/c2.rs +++ /dev/null @@ -1,105 +0,0 @@ -use mpc_aio::protocol::garble::{Execute, GCError}; -use rand::{thread_rng, Rng}; -use tls_2pc_core::CIRCUIT_2; - -/// Executes c2 as PRFLeader -/// -/// Returns inner_hash_state -pub async fn leader_c2( - exec: &mut T, - p1_inner_hash: [u8; 32], -) -> Result<[u32; 8], GCError> { - let circ = CIRCUIT_2.clone(); - - // convert to little-endian - let input_inner_hash = - circ.input_value(1, p1_inner_hash.iter().rev().copied().collect::>())?; - - let mask: Vec = thread_rng().gen::<[u8; 32]>().to_vec(); - let input_mask = circ.input_value(3, mask.clone())?; - - let inputs = vec![input_inner_hash, input_mask]; - let out = exec.execute(circ, &inputs).await?.decode()?; - - // todo make this less gross - let masked_inner_hash_state = if let mpc_circuits::Value::Bytes(v) = - out.get(0).expect("Circuit 2 should have output 0").value() - { - v - } else { - panic!("Circuit 2 output 0 should be 32 bytes") - }; - - // remove XOR mask and convert to big-endian - let inner_hash_state = masked_inner_hash_state - .iter() - .zip(mask.iter()) - .map(|(v, m)| v ^ m) - .collect::>(); - - let inner_hash_state: [u32; 8] = inner_hash_state - .chunks_exact(4) - .map(|chunk| u32::from_be_bytes([chunk[3], chunk[2], chunk[1], chunk[0]])) - .rev() - .collect::>() - .try_into() - .expect("Circuit 2 output 0 should be 32 bytes"); - - Ok(inner_hash_state) -} - -/// Executes c2 as PRFFollower -/// -/// Returns outer_hash_state -pub async fn follower_c2( - exec: &mut T, - outer_hash_state: [u32; 8], - p2: [u8; 32], -) -> Result<[u32; 8], GCError> { - let circ = CIRCUIT_2.clone(); - - // convert to little-endian - let input_outer_hash_state = circ.input_value( - 0, - outer_hash_state - .into_iter() - .rev() - .map(|v| v.to_le_bytes()) - .flatten() - .collect::>(), - )?; - - let input_p2 = circ.input_value(2, p2[..16].iter().rev().copied().collect::>())?; - - let mask: Vec = thread_rng().gen::<[u8; 32]>().to_vec(); - let input_mask = circ.input_value(4, mask.clone())?; - - let inputs = vec![input_outer_hash_state, input_p2, input_mask]; - let out = exec.execute(circ, &inputs).await?.decode()?; - - // todo make this less gross - let masked_outer_hash_state = if let mpc_circuits::Value::Bytes(v) = - out.get(1).expect("Circuit 2 should have output 1").value() - { - v - } else { - panic!("Circuit 2 output 1 should be 32 bytes") - }; - - // remove XOR mask and convert to big-endian - let outer_hash_state = masked_outer_hash_state - .iter() - .zip(mask.iter()) - .map(|(v, m)| v ^ m) - .collect::>(); - - let outer_hash_state: [u32; 8] = outer_hash_state - .chunks_exact(4) - .map(|chunk| u32::from_be_bytes([chunk[3], chunk[2], chunk[1], chunk[0]])) - .rev() - .collect::>() - .try_into() - .expect("Circuit 2 output 1 should be 32 bytes"); - - Ok(outer_hash_state) -} diff --git a/tls/tls-2pc-aio/src/prf/circuits/c3.rs b/tls/tls-2pc-aio/src/prf/circuits/c3.rs deleted file mode 100644 index 4e617cf8e..000000000 --- a/tls/tls-2pc-aio/src/prf/circuits/c3.rs +++ /dev/null @@ -1,165 +0,0 @@ -use mpc_aio::protocol::garble::{Execute, GCError}; -use rand::{thread_rng, Rng}; -use tls_2pc_core::{SessionKeyShares, CIRCUIT_3}; - -/// Executes c3 as PRFLeader -/// -/// Returns session key shares -pub async fn leader_c3( - exec: &mut T, - p1_inner_hash: [u8; 32], - p2_inner_hash: [u8; 32], -) -> Result { - let circ = CIRCUIT_3.clone(); - - let input_p1_inner_hash = - circ.input_value(5, p1_inner_hash.iter().rev().copied().collect::>())?; - - let input_p2_inner_hash = - circ.input_value(6, p2_inner_hash.iter().rev().copied().collect::>())?; - - let cwk_mask: Vec = thread_rng().gen::<[u8; 16]>().to_vec(); - let input_cwk_mask = circ.input_value(7, cwk_mask.clone())?; - - let swk_mask: Vec = thread_rng().gen::<[u8; 16]>().to_vec(); - let input_swk_mask = circ.input_value(8, swk_mask.clone())?; - - let civ_mask: Vec = thread_rng().gen::<[u8; 4]>().to_vec(); - let input_civ_mask = circ.input_value(9, civ_mask.clone())?; - - let siv_mask: Vec = thread_rng().gen::<[u8; 4]>().to_vec(); - let input_siv_mask = circ.input_value(10, siv_mask.clone())?; - - let inputs = vec![ - input_p1_inner_hash, - input_p2_inner_hash, - input_cwk_mask, - input_swk_mask, - input_civ_mask, - input_siv_mask, - ]; - - let out = exec.execute(circ, &inputs).await?.decode()?; - - // todo make this less gross - let cwk_masked = if let mpc_circuits::Value::Bytes(v) = - out.get(0).expect("Circuit 3 should have output 1").value() - { - v - } else { - panic!("Circuit 3 output 0 should be 16 bytes") - }; - - let swk_masked = if let mpc_circuits::Value::Bytes(v) = - out.get(1).expect("Circuit 3 should have output 0").value() - { - v - } else { - panic!("Circuit 3 output 1 should be 16 bytes") - }; - - let civ_masked = if let mpc_circuits::Value::Bytes(v) = - out.get(2).expect("Circuit 3 should have output 2").value() - { - v - } else { - panic!("Circuit 3 output 2 should be 4 bytes") - }; - - let siv_masked = if let mpc_circuits::Value::Bytes(v) = - out.get(3).expect("Circuit 3 should have output 2").value() - { - v - } else { - panic!("Circuit 3 output 3 should be 4 bytes") - }; - - // Only the leader removes their key masks - let cwk = cwk_masked - .iter() - .zip(cwk_mask.iter()) - .map(|(a, b)| a ^ b) - .rev() - .collect::>(); - let swk = swk_masked - .iter() - .zip(swk_mask.iter()) - .map(|(a, b)| a ^ b) - .rev() - .collect::>(); - let civ = civ_masked - .iter() - .zip(civ_mask.iter()) - .map(|(a, b)| a ^ b) - .rev() - .collect::>(); - let siv = siv_masked - .iter() - .zip(siv_mask.iter()) - .map(|(a, b)| a ^ b) - .rev() - .collect::>(); - - // The leader's key shares are k ⊕ follower_mask - let cwk: [u8; 16] = cwk.try_into().expect("cwk should be 16 bytes"); - let swk: [u8; 16] = swk.try_into().expect("swk should be 16 bytes"); - let civ: [u8; 4] = civ.try_into().expect("civ should be 4 bytes"); - let siv: [u8; 4] = siv.try_into().expect("siv should be 4 bytes"); - - Ok(SessionKeyShares::new(cwk, swk, civ, siv)) -} - -/// Executes c3 as PRFFollower -/// -/// Returns outer_hash_state -pub async fn follower_c3( - exec: &mut T, - outer_hash_state: [u32; 8], -) -> Result { - let circ = CIRCUIT_3.clone(); - - let input_outer_hash_state = circ.input_value( - 0, - outer_hash_state - .into_iter() - .rev() - .map(|v| v.to_le_bytes()) - .flatten() - .collect::>(), - )?; - - let mut cwk_mask: Vec = thread_rng().gen::<[u8; 16]>().to_vec(); - let input_cwk_mask = circ.input_value(1, cwk_mask.clone())?; - - let mut swk_mask: Vec = thread_rng().gen::<[u8; 16]>().to_vec(); - let input_swk_mask = circ.input_value(2, swk_mask.clone())?; - - let mut civ_mask: Vec = thread_rng().gen::<[u8; 4]>().to_vec(); - let input_civ_mask = circ.input_value(3, civ_mask.clone())?; - - let mut siv_mask: Vec = thread_rng().gen::<[u8; 4]>().to_vec(); - let input_siv_mask = circ.input_value(4, siv_mask.clone())?; - - let inputs = vec![ - input_outer_hash_state, - input_cwk_mask, - input_swk_mask, - input_civ_mask, - input_siv_mask, - ]; - - let _ = exec.execute(circ, &inputs).await?.decode()?; - - cwk_mask.reverse(); - swk_mask.reverse(); - civ_mask.reverse(); - siv_mask.reverse(); - - // The followers's key shares are just their masks - let cwk: [u8; 16] = cwk_mask.try_into().expect("cwk should be 16 bytes"); - let swk: [u8; 16] = swk_mask.try_into().expect("swk should be 16 bytes"); - let civ: [u8; 4] = civ_mask.try_into().expect("civ should be 4 bytes"); - let siv: [u8; 4] = siv_mask.try_into().expect("siv should be 4 bytes"); - - Ok(SessionKeyShares::new(cwk, swk, civ, siv)) -} diff --git a/tls/tls-2pc-aio/src/prf/circuits/mod.rs b/tls/tls-2pc-aio/src/prf/circuits/mod.rs deleted file mode 100644 index 12354a886..000000000 --- a/tls/tls-2pc-aio/src/prf/circuits/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod c1; -mod c2; -mod c3; - -pub use c1::*; -pub use c2::*; -pub use c3::*; diff --git a/tls/tls-2pc-aio/src/prf/follower.rs b/tls/tls-2pc-aio/src/prf/follower.rs deleted file mode 100644 index 7b38f47f2..000000000 --- a/tls/tls-2pc-aio/src/prf/follower.rs +++ /dev/null @@ -1,168 +0,0 @@ -use super::{circuits, PRFChannel, PRFError}; -use futures::{SinkExt, StreamExt}; -use mpc_aio::protocol::{garble::Execute, point_addition::P256SecretShare}; -use tls_2pc_core::{ - prf::{self as core, follower_state as state, PRFMessage}, - SessionKeyShares, -}; -use utils_aio::expect_msg_or_err; - -pub struct MasterSecret { - core: core::PRFFollower, -} - -pub struct ClientFinished { - core: core::PRFFollower, -} - -pub struct ServerFinished { - core: core::PRFFollower, -} - -pub struct PRFFollower -where - G: Execute + Send, -{ - state: S, - channel: PRFChannel, - gc_exec: G, -} - -impl PRFFollower -where - G: Execute + Send, -{ - pub fn new(channel: PRFChannel, gc_exec: G) -> PRFFollower { - PRFFollower { - state: MasterSecret { - core: core::PRFFollower::new(), - }, - channel, - gc_exec, - } - } - - pub async fn compute_session_keys( - mut self, - secret_share: P256SecretShare, - ) -> Result<(SessionKeyShares, PRFFollower), PRFError> { - let outer_hash_state = circuits::follower_c1(&mut self.gc_exec, secret_share).await?; - - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::LeaderMs1, - PRFError::UnexpectedMessage - )?; - let (msg, core) = self.state.core.next(outer_hash_state, msg); - - self.channel.send(PRFMessage::FollowerMs1(msg)).await?; - - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::LeaderMs2, - PRFError::UnexpectedMessage - )?; - let (msg, core) = core.next(msg); - self.channel.send(PRFMessage::FollowerMs2(msg)).await?; - - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::LeaderMs3, - PRFError::UnexpectedMessage - )?; - let core = core.next(msg); - - let p2 = core.p2(); - let outer_hash_state = - circuits::follower_c2(&mut self.gc_exec, outer_hash_state, p2).await?; - - let core = core.next().next(outer_hash_state); - - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::LeaderKe1, - PRFError::UnexpectedMessage - )?; - let (msg, core) = core.next(msg); - self.channel.send(PRFMessage::FollowerKe1(msg)).await?; - - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::LeaderKe2, - PRFError::UnexpectedMessage - )?; - let (msg, core) = core.next(msg); - self.channel.send(PRFMessage::FollowerKe2(msg)).await?; - - let session_keys = circuits::follower_c3(&mut self.gc_exec, outer_hash_state).await?; - - Ok(( - session_keys, - PRFFollower { - state: ClientFinished { core }, - channel: self.channel, - gc_exec: self.gc_exec, - }, - )) - } -} - -impl PRFFollower -where - G: Execute + Send, -{ - /// Computes client finished data using handshake hash - /// - /// Returns next state - pub async fn compute_client_finished( - mut self, - ) -> Result, PRFError> { - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::LeaderCf1, - PRFError::UnexpectedMessage - )?; - let (msg, core) = self.state.core.next(msg); - self.channel.send(PRFMessage::FollowerCf1(msg)).await?; - - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::LeaderCf2, - PRFError::UnexpectedMessage - )?; - let (msg, core) = core.next(msg); - self.channel.send(PRFMessage::FollowerCf2(msg)).await?; - - Ok(PRFFollower { - state: ServerFinished { core }, - channel: self.channel, - gc_exec: self.gc_exec, - }) - } -} - -impl PRFFollower -where - G: Execute + Send, -{ - /// Computes server finished data using handshake hash - pub async fn compute_server_finished(mut self) -> Result<(), PRFError> { - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::LeaderSf1, - PRFError::UnexpectedMessage - )?; - let (msg, core) = self.state.core.next(msg); - self.channel.send(PRFMessage::FollowerSf1(msg)).await?; - - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::LeaderSf2, - PRFError::UnexpectedMessage - )?; - let msg = core.next(msg); - self.channel.send(PRFMessage::FollowerSf2(msg)).await?; - - Ok(()) - } -} diff --git a/tls/tls-2pc-aio/src/prf/leader.rs b/tls/tls-2pc-aio/src/prf/leader.rs deleted file mode 100644 index 963e14205..000000000 --- a/tls/tls-2pc-aio/src/prf/leader.rs +++ /dev/null @@ -1,184 +0,0 @@ -use super::{circuits, PRFChannel, PRFError}; -use futures::{SinkExt, StreamExt}; -use mpc_aio::protocol::{garble::Execute, point_addition::P256SecretShare}; -use tls_2pc_core::{ - prf::{self as core, leader_state as state, PRFMessage}, - SessionKeyShares, -}; -use utils_aio::expect_msg_or_err; - -pub struct MasterSecret { - core: core::PRFLeader, -} - -pub struct ClientFinished { - core: core::PRFLeader, -} - -pub struct ServerFinished { - core: core::PRFLeader, -} - -pub struct PRFLeader -where - G: Execute + Send, -{ - state: S, - channel: PRFChannel, - gc_exec: G, -} - -impl PRFLeader -where - G: Execute + Send, -{ - pub fn new(channel: PRFChannel, gc_exec: G) -> PRFLeader { - PRFLeader { - state: MasterSecret { - core: core::PRFLeader::new(), - }, - channel, - gc_exec, - } - } - - /// Computes leader's shares of the TLS session keys using the session randoms and their - /// share of the PMS - /// - /// Returns session key shares - pub async fn compute_session_keys( - mut self, - client_random: [u8; 32], - server_random: [u8; 32], - secret_share: P256SecretShare, - ) -> Result<(SessionKeyShares, PRFLeader), PRFError> { - let inner_hash_state = circuits::leader_c1(&mut self.gc_exec, secret_share).await?; - let (msg, core) = self - .state - .core - .next(client_random, server_random, inner_hash_state); - - self.channel.send(PRFMessage::LeaderMs1(msg)).await?; - - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::FollowerMs1, - PRFError::UnexpectedMessage - )?; - let (msg, core) = core.next(msg); - self.channel.send(PRFMessage::LeaderMs2(msg)).await?; - - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::FollowerMs2, - PRFError::UnexpectedMessage - )?; - let (msg, core) = core.next(msg); - self.channel.send(PRFMessage::LeaderMs3(msg)).await?; - - let p1_inner_hash = core.p1_inner_hash(); - let inner_hash_state = circuits::leader_c2(&mut self.gc_exec, p1_inner_hash).await?; - - let (msg, core) = core.next().next(inner_hash_state); - self.channel.send(PRFMessage::LeaderKe1(msg)).await?; - - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::FollowerKe1, - PRFError::UnexpectedMessage - )?; - let (msg, core) = core.next(msg); - self.channel.send(PRFMessage::LeaderKe2(msg)).await?; - - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::FollowerKe2, - PRFError::UnexpectedMessage - )?; - let core = core.next(msg); - let p1_inner_hash = core.p1_inner_hash(); - let p2_inner_hash = core.p2_inner_hash(); - - let session_keys = - circuits::leader_c3(&mut self.gc_exec, p1_inner_hash, p2_inner_hash).await?; - - Ok(( - session_keys, - PRFLeader { - state: ClientFinished { core: core.next() }, - channel: self.channel, - gc_exec: self.gc_exec, - }, - )) - } -} - -impl PRFLeader -where - G: Execute + Send, -{ - /// Computes client finished data using handshake hash - /// - /// Returns client finished data and next state - pub async fn compute_client_finished( - mut self, - hash: &[u8], - ) -> Result<([u8; 12], PRFLeader), PRFError> { - let (msg, core) = self.state.core.next(hash); - self.channel.send(PRFMessage::LeaderCf1(msg)).await?; - - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::FollowerCf1, - PRFError::UnexpectedMessage - )?; - let (msg, core) = core.next(msg); - self.channel.send(PRFMessage::LeaderCf2(msg)).await?; - - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::FollowerCf2, - PRFError::UnexpectedMessage - )?; - let (vd, core) = core.next(msg); - - Ok(( - vd, - PRFLeader { - state: ServerFinished { core }, - channel: self.channel, - gc_exec: self.gc_exec, - }, - )) - } -} - -impl PRFLeader -where - G: Execute + Send, -{ - /// Computes server finished data using handshake hash - /// - /// Returns server finished data - pub async fn compute_server_finished(mut self, hash: &[u8]) -> Result<[u8; 12], PRFError> { - let (msg, core) = self.state.core.next(hash); - self.channel.send(PRFMessage::LeaderSf1(msg)).await?; - - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::FollowerSf1, - PRFError::UnexpectedMessage - )?; - let (msg, core) = core.next(msg); - self.channel.send(PRFMessage::LeaderSf2(msg)).await?; - - let msg = expect_msg_or_err!( - self.channel.next().await, - PRFMessage::FollowerSf2, - PRFError::UnexpectedMessage - )?; - let vd = core.next(msg); - - Ok(vd) - } -} diff --git a/tls/tls-2pc-aio/src/prf/mod.rs b/tls/tls-2pc-aio/src/prf/mod.rs deleted file mode 100644 index e15f78d86..000000000 --- a/tls/tls-2pc-aio/src/prf/mod.rs +++ /dev/null @@ -1,182 +0,0 @@ -mod circuits; -mod follower; -mod leader; - -use mpc_aio::protocol::garble::GCError; -use utils_aio::Channel; - -pub use follower::PRFFollower; -pub use leader::PRFLeader; -pub use tls_2pc_core::msgs::prf::PRFMessage; - -pub type PRFChannel = Box>; - -#[derive(Debug, thiserror::Error)] -pub enum PRFError { - #[error("error occurred during garbled circuit protocol")] - GCError(#[from] GCError), - #[error("io error")] - IOError(#[from] std::io::Error), - #[error("unexpected message: {0:?}")] - UnexpectedMessage(PRFMessage), -} - -#[cfg(test)] -mod tests { - use mpc_aio::protocol::{ - garble::exec::dual::mock_dualex_pair, - point_addition::{P256SecretShare, PaillierFollower, PaillierLeader, PointAddition2PC}, - }; - use p256::{elliptic_curve::sec1::ToEncodedPoint, SecretKey}; - use rand::{thread_rng, Rng}; - use tls_2pc_core::prf::utils::{hmac_sha256, seed_ke, seed_ms}; - use utils_aio::duplex::DuplexChannel; - - use super::*; - - async fn get_shares() -> (P256SecretShare, P256SecretShare) { - let (leader_channel, follower_channel) = DuplexChannel::new(); - - let mut leader = PaillierLeader::new(Box::new(leader_channel)); - let mut follower = PaillierFollower::new(Box::new(follower_channel)); - - let mut rng = thread_rng(); - - let server_secret = SecretKey::random(&mut rng); - let server_pk = server_secret.public_key().to_projective(); - - let leader_secret = SecretKey::random(&mut rng); - let leader_point = - (&server_pk * &leader_secret.to_nonzero_scalar()).to_encoded_point(false); - - let follower_secret = SecretKey::random(&mut rng); - let follower_point = - (&server_pk * &follower_secret.to_nonzero_scalar()).to_encoded_point(false); - - let (task_m, task_s) = tokio::join!( - tokio::spawn(async move { leader.add(&leader_point).await }), - tokio::spawn(async move { follower.add(&follower_point).await }) - ); - - let leader_share = task_m.unwrap().unwrap(); - let follower_share = task_s.unwrap().unwrap(); - - (leader_share, follower_share) - } - - /// Expands pre-master secret into session key using TLS 1.2 PRF - /// Returns session keys - pub fn key_expansion_tls12( - client_random: &[u8; 32], - server_random: &[u8; 32], - pms: &[u8], - ) -> ([u8; 16], [u8; 16], [u8; 4], [u8; 4]) { - // first expand pms into ms - let seed = seed_ms(client_random, server_random); - let a1 = hmac_sha256(pms, &seed); - let a2 = hmac_sha256(pms, &a1); - let mut a1_seed = [0u8; 109]; - a1_seed[..32].copy_from_slice(&a1); - a1_seed[32..].copy_from_slice(&seed); - let mut a2_seed = [0u8; 109]; - a2_seed[..32].copy_from_slice(&a2); - a2_seed[32..].copy_from_slice(&seed); - let p1 = hmac_sha256(pms, &a1_seed); - let p2 = hmac_sha256(pms, &a2_seed); - let mut ms = [0u8; 48]; - ms[..32].copy_from_slice(&p1); - ms[32..].copy_from_slice(&p2[..16]); - - // expand ms into session keys - let seed = seed_ke(client_random, server_random); - let a1 = hmac_sha256(&ms, &seed); - let a2 = hmac_sha256(&ms, &a1); - let mut a1_seed = [0u8; 109]; - a1_seed[..32].copy_from_slice(&a1); - a1_seed[32..].copy_from_slice(&seed); - let mut a2_seed = [0u8; 109]; - a2_seed[..32].copy_from_slice(&a2); - a2_seed[32..].copy_from_slice(&seed); - let p1 = hmac_sha256(&ms, &a1_seed); - let p2 = hmac_sha256(&ms, &a2_seed); - let mut ek = [0u8; 40]; - ek[..32].copy_from_slice(&p1); - ek[32..].copy_from_slice(&p2[..8]); - - let mut cwk = [0u8; 16]; - cwk.copy_from_slice(&ek[..16]); - let mut swk = [0u8; 16]; - swk.copy_from_slice(&ek[16..32]); - let mut civ = [0u8; 4]; - civ.copy_from_slice(&ek[32..36]); - let mut siv = [0u8; 4]; - siv.copy_from_slice(&ek[36..]); - - (cwk, swk, civ, siv) - } - - #[ignore = "expensive"] - #[tokio::test] - async fn test_prf() { - let (leader_channel, follower_channel) = DuplexChannel::::new(); - let (gc_leader, gc_follower) = mock_dualex_pair(); - let leader = PRFLeader::new(Box::new(leader_channel), gc_leader); - let follower = PRFFollower::new(Box::new(follower_channel), gc_follower); - - let client_random: [u8; 32] = thread_rng().gen(); - let server_random: [u8; 32] = thread_rng().gen(); - - let (leader_share, follower_share) = get_shares().await; - - let pms = leader_share + follower_share; - - let (task_leader, task_follower) = tokio::join!( - tokio::task::spawn_blocking(move || { - futures::executor::block_on(leader.compute_session_keys( - client_random, - server_random, - leader_share, - )) - }), - tokio::task::spawn_blocking(move || { - futures::executor::block_on(follower.compute_session_keys(follower_share)) - }) - ); - - let (leader_keys, _leader) = task_leader.unwrap().unwrap(); - let (follower_keys, _follower) = task_follower.unwrap().unwrap(); - - let cwk = leader_keys - .cwk() - .iter() - .zip(follower_keys.cwk()) - .map(|(a, b)| a ^ b) - .collect::>(); - let swk = leader_keys - .swk() - .iter() - .zip(follower_keys.swk()) - .map(|(a, b)| a ^ b) - .collect::>(); - let civ = leader_keys - .civ() - .iter() - .zip(follower_keys.civ()) - .map(|(a, b)| a ^ b) - .collect::>(); - let siv = leader_keys - .siv() - .iter() - .zip(follower_keys.siv()) - .map(|(a, b)| a ^ b) - .collect::>(); - - let (expected_cwk, expected_swk, expected_civ, expected_siv) = - key_expansion_tls12(&client_random, &server_random, &pms); - - assert_eq!(cwk, expected_cwk); - assert_eq!(swk, expected_swk); - assert_eq!(civ, expected_civ); - assert_eq!(siv, expected_siv); - } -} diff --git a/tls/tls-2pc-core/Cargo.toml b/tls/tls-2pc-core/Cargo.toml deleted file mode 100644 index 69ab59aac..000000000 --- a/tls/tls-2pc-core/Cargo.toml +++ /dev/null @@ -1,64 +0,0 @@ -[package] -name = "tlsn-tls-2pc-core" -version = "0.1.0" -edition = "2021" - -[lib] -name = "tls_2pc_core" - -[features] -default = ["prf", "circuits"] -prf = [] -circuits = ["c1", "c2", "c3", "c4", "c5", "c6", "c7"] -c1 = [] -c2 = [] -c3 = [] -c4 = [] -c5 = [] -c6 = [] -c7 = [] -build-circuits = [] -serde = ["dep:serde"] - -[dependencies] -tlsn-mpc-circuits.workspace = true -tlsn-mpc-core.workspace = true -tlsn-utils.workspace = true -share-conversion-core.workspace = true - -sha2 = { workspace = true, features = ["compress"] } -digest.workspace = true -hmac.workspace = true -elliptic-curve.workspace = true -rand.workspace = true -thiserror.workspace = true -serde = { workspace = true, features = ["derive"], optional = true } -once_cell.workspace = true - -[dev-dependencies] -criterion.workspace = true -rand_chacha.workspace = true -hex.workspace = true -num = { workspace = true, features = ["rand"] } -aes = { workspace = true, features = [] } - -[build-dependencies] -tlsn-mpc-circuits.workspace = true -tlsn-tls-circuits = { path = "../tls-circuits" } -regex.workspace = true -prost.workspace = true -rayon.workspace = true - -[[test]] -# don't run the heavy circuit_test unless explicitely invoked with -# cargo test --test circuit_test -name = "circuit_test" -test = false - -[[bench]] -name = "garble" -harness = false - -[[bench]] -name = "circuit" -harness = false diff --git a/tls/tls-2pc-core/benches/circuit.rs b/tls/tls-2pc-core/benches/circuit.rs deleted file mode 100644 index f0d51b68c..000000000 --- a/tls/tls-2pc-core/benches/circuit.rs +++ /dev/null @@ -1,33 +0,0 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use mpc_circuits::Circuit; - -static CIRCUITS: &[&[u8]] = &[ - #[cfg(feature = "c1")] - tls_2pc_core::CIRCUIT_1_BYTES, - #[cfg(feature = "c2")] - tls_2pc_core::CIRCUIT_2_BYTES, - #[cfg(feature = "c3")] - tls_2pc_core::CIRCUIT_3_BYTES, - #[cfg(feature = "c4")] - tls_2pc_core::CIRCUIT_4_BYTES, - #[cfg(feature = "c5")] - tls_2pc_core::CIRCUIT_5_BYTES, - #[cfg(feature = "c6")] - tls_2pc_core::CIRCUIT_6_BYTES, - #[cfg(feature = "c7")] - tls_2pc_core::CIRCUIT_7_BYTES, -]; - -fn criterion_benchmark(c: &mut Criterion) { - let mut group = c.benchmark_group("load_circuits"); - - for circ_bytes in CIRCUITS { - let circ = Circuit::load_bytes(circ_bytes).unwrap(); - group.bench_function(circ.description(), |b| { - b.iter(|| black_box(Circuit::load_bytes(circ_bytes).unwrap())) - }); - } -} - -criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); diff --git a/tls/tls-2pc-core/benches/garble.rs b/tls/tls-2pc-core/benches/garble.rs deleted file mode 100644 index 151600b71..000000000 --- a/tls/tls-2pc-core/benches/garble.rs +++ /dev/null @@ -1,46 +0,0 @@ -use aes::{cipher::NewBlockCipher, Aes128}; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use mpc_circuits::Circuit; -use mpc_core::garble::{FullInputLabels, GarbledCircuit}; -use rand::thread_rng; - -static CIRCUITS: &[&[u8]] = &[ - #[cfg(feature = "c1")] - tls_2pc_core::CIRCUIT_1_BYTES, - #[cfg(feature = "c2")] - tls_2pc_core::CIRCUIT_2_BYTES, - #[cfg(feature = "c3")] - tls_2pc_core::CIRCUIT_3_BYTES, - #[cfg(feature = "c4")] - tls_2pc_core::CIRCUIT_4_BYTES, - #[cfg(feature = "c5")] - tls_2pc_core::CIRCUIT_5_BYTES, - #[cfg(feature = "c6")] - tls_2pc_core::CIRCUIT_6_BYTES, - #[cfg(feature = "c7")] - tls_2pc_core::CIRCUIT_7_BYTES, -]; - -fn criterion_benchmark(c: &mut Criterion) { - let mut group = c.benchmark_group("garble_circuits"); - - for circ in CIRCUITS { - let circ = Circuit::load_bytes(circ).unwrap(); - group.bench_function(circ.description(), |b| { - let mut rng = thread_rng(); - let cipher = Aes128::new_from_slice(&[0u8; 16]).unwrap(); - let (labels, delta) = FullInputLabels::generate_set(&mut rng, &circ, None); - b.iter(|| { - black_box(GarbledCircuit::generate( - &cipher, - circ.clone(), - delta, - &labels, - )) - }) - }); - } -} - -criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); diff --git a/tls/tls-2pc-core/circuits/bin/c1.bin b/tls/tls-2pc-core/circuits/bin/c1.bin deleted file mode 100644 index 44a94e93e..000000000 Binary files a/tls/tls-2pc-core/circuits/bin/c1.bin and /dev/null differ diff --git a/tls/tls-2pc-core/circuits/bin/c2.bin b/tls/tls-2pc-core/circuits/bin/c2.bin deleted file mode 100644 index 9088487f0..000000000 Binary files a/tls/tls-2pc-core/circuits/bin/c2.bin and /dev/null differ diff --git a/tls/tls-2pc-core/circuits/bin/c3.bin b/tls/tls-2pc-core/circuits/bin/c3.bin deleted file mode 100644 index 8a920a590..000000000 Binary files a/tls/tls-2pc-core/circuits/bin/c3.bin and /dev/null differ diff --git a/tls/tls-2pc-core/circuits/bin/c4.bin b/tls/tls-2pc-core/circuits/bin/c4.bin deleted file mode 100644 index 7d03322c6..000000000 Binary files a/tls/tls-2pc-core/circuits/bin/c4.bin and /dev/null differ diff --git a/tls/tls-2pc-core/circuits/bin/c5.bin b/tls/tls-2pc-core/circuits/bin/c5.bin deleted file mode 100644 index a56122952..000000000 Binary files a/tls/tls-2pc-core/circuits/bin/c5.bin and /dev/null differ diff --git a/tls/tls-2pc-core/circuits/bin/c6.bin b/tls/tls-2pc-core/circuits/bin/c6.bin deleted file mode 100644 index 34e9260d1..000000000 Binary files a/tls/tls-2pc-core/circuits/bin/c6.bin and /dev/null differ diff --git a/tls/tls-2pc-core/circuits/bin/c7.bin b/tls/tls-2pc-core/circuits/bin/c7.bin deleted file mode 100644 index 5fc341d5d..000000000 Binary files a/tls/tls-2pc-core/circuits/bin/c7.bin and /dev/null differ diff --git a/tls/tls-2pc-core/src/lib.rs b/tls/tls-2pc-core/src/lib.rs deleted file mode 100644 index 0b4681cc4..000000000 --- a/tls/tls-2pc-core/src/lib.rs +++ /dev/null @@ -1,79 +0,0 @@ -pub mod msgs; -#[cfg(feature = "prf")] -pub mod prf; - -pub use mpc_circuits::{Circuit, CircuitError}; - -use once_cell::sync::Lazy; -use std::sync::Arc; - -#[cfg(feature = "c1")] -pub static CIRCUIT_1_BYTES: &[u8] = std::include_bytes!("../circuits/bin/c1.bin"); -#[cfg(feature = "c2")] -pub static CIRCUIT_2_BYTES: &[u8] = std::include_bytes!("../circuits/bin/c2.bin"); -#[cfg(feature = "c3")] -pub static CIRCUIT_3_BYTES: &[u8] = std::include_bytes!("../circuits/bin/c3.bin"); -#[cfg(feature = "c4")] -pub static CIRCUIT_4_BYTES: &[u8] = std::include_bytes!("../circuits/bin/c4.bin"); -#[cfg(feature = "c5")] -pub static CIRCUIT_5_BYTES: &[u8] = std::include_bytes!("../circuits/bin/c5.bin"); -#[cfg(feature = "c6")] -pub static CIRCUIT_6_BYTES: &[u8] = std::include_bytes!("../circuits/bin/c6.bin"); -#[cfg(feature = "c7")] -pub static CIRCUIT_7_BYTES: &[u8] = std::include_bytes!("../circuits/bin/c7.bin"); - -#[cfg(feature = "c1")] -pub static CIRCUIT_1: Lazy> = - Lazy::new(|| Circuit::load_bytes(CIRCUIT_1_BYTES).unwrap()); -#[cfg(feature = "c2")] -pub static CIRCUIT_2: Lazy> = - Lazy::new(|| Circuit::load_bytes(CIRCUIT_2_BYTES).unwrap()); -#[cfg(feature = "c3")] -pub static CIRCUIT_3: Lazy> = - Lazy::new(|| Circuit::load_bytes(CIRCUIT_3_BYTES).unwrap()); -#[cfg(feature = "c4")] -pub static CIRCUIT_4: Lazy> = - Lazy::new(|| Circuit::load_bytes(CIRCUIT_4_BYTES).unwrap()); -#[cfg(feature = "c5")] -pub static CIRCUIT_5: Lazy> = - Lazy::new(|| Circuit::load_bytes(CIRCUIT_5_BYTES).unwrap()); -#[cfg(feature = "c6")] -pub static CIRCUIT_6: Lazy> = - Lazy::new(|| Circuit::load_bytes(CIRCUIT_6_BYTES).unwrap()); -#[cfg(feature = "c7")] -pub static CIRCUIT_7: Lazy> = - Lazy::new(|| Circuit::load_bytes(CIRCUIT_7_BYTES).unwrap()); - -pub struct SessionKeyShares { - cwk: [u8; 16], - swk: [u8; 16], - civ: [u8; 4], - siv: [u8; 4], -} - -impl SessionKeyShares { - /// Creates new SessionKeyShares - pub fn new(cwk: [u8; 16], swk: [u8; 16], civ: [u8; 4], siv: [u8; 4]) -> Self { - Self { cwk, swk, civ, siv } - } - - /// Returns client_write_key share - pub fn cwk(&self) -> [u8; 16] { - self.cwk - } - - /// Returns server_write_key share - pub fn swk(&self) -> [u8; 16] { - self.swk - } - - /// Returns client IV share - pub fn civ(&self) -> [u8; 4] { - self.civ - } - - /// Returns server IV share - pub fn siv(&self) -> [u8; 4] { - self.siv - } -} diff --git a/tls/tls-2pc-core/src/msgs/mod.rs b/tls/tls-2pc-core/src/msgs/mod.rs deleted file mode 100644 index a2facbb5d..000000000 --- a/tls/tls-2pc-core/src/msgs/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -#[cfg(feature = "prf")] -pub mod prf; diff --git a/tls/tls-2pc-core/src/msgs/prf.rs b/tls/tls-2pc-core/src/msgs/prf.rs deleted file mode 100644 index ab8f2e255..000000000 --- a/tls/tls-2pc-core/src/msgs/prf.rs +++ /dev/null @@ -1,149 +0,0 @@ -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum PRFMessage { - LeaderMs1(LeaderMs1), - FollowerMs1(FollowerMs1), - LeaderMs2(LeaderMs2), - FollowerMs2(FollowerMs2), - LeaderMs3(LeaderMs3), - FollowerMs3(FollowerMs3), - LeaderKe1(LeaderKe1), - FollowerKe1(FollowerKe1), - LeaderKe2(LeaderKe2), - FollowerKe2(FollowerKe2), - LeaderCf1(LeaderCf1), - FollowerCf1(FollowerCf1), - LeaderCf2(LeaderCf2), - FollowerCf2(FollowerCf2), - LeaderSf1(LeaderSf1), - FollowerSf1(FollowerSf1), - LeaderSf2(LeaderSf2), - FollowerSf2(FollowerSf2), -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct LeaderMs1 { - /// H((pms xor ipad) || seed) - pub a1_inner_hash: [u8; 32], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct LeaderMs2 { - /// H((pms xor ipad) || a1) - pub a2_inner_hash: [u8; 32], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct LeaderMs3 { - /// H((pms xor ipad) || a2) - pub p2_inner_hash: [u8; 32], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct LeaderKe1 { - /// H((ms xor ipad) || seed) - pub a1_inner_hash: [u8; 32], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct LeaderKe2 { - /// H((ms xor ipad) || a1) - pub a2_inner_hash: [u8; 32], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct LeaderCf1 { - /// H((ms xor ipad) || seed) - pub a1_inner_hash: [u8; 32], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct LeaderCf2 { - /// H((ms xor ipad) || a1 || seed) - pub p1_inner_hash: [u8; 32], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct LeaderSf1 { - /// H((ms xor ipad) || seed) - pub a1_inner_hash: [u8; 32], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct LeaderSf2 { - /// H((ms xor ipad) || a1 || seed) - pub sf_vd_inner_hash: [u8; 32], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct FollowerMs1 { - /// H((pms xor opad) || H((pms xor ipad) || seed)) - pub a1: [u8; 32], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct FollowerMs2 { - /// H((pms xor opad) || H((pms xor ipad) || a1)) - pub a2: [u8; 32], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct FollowerMs3 { - /// H((pms xor opad) || H((pms xor ipad) || a2 || seed)) - pub p2: [u8; 32], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct FollowerKe1 { - /// H((ms xor opad) || H((ms xor ipad) || seed)) - pub a1: [u8; 32], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct FollowerKe2 { - /// H((ms xor opad) || H((ms xor ipad) || a1)) - pub a2: [u8; 32], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct FollowerCf1 { - /// H((ms xor opad) || H((ms xor ipad) || seed)) - pub a1: [u8; 32], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct FollowerCf2 { - pub verify_data: [u8; 12], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct FollowerSf1 { - /// H((ms xor opad) || H((ms xor ipad) || seed)) - pub a1: [u8; 32], -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct FollowerSf2 { - pub verify_data: [u8; 12], -} diff --git a/tls/tls-2pc-core/src/prf/follower.rs b/tls/tls-2pc-core/src/prf/follower.rs deleted file mode 100644 index 258234f4b..000000000 --- a/tls/tls-2pc-core/src/prf/follower.rs +++ /dev/null @@ -1,275 +0,0 @@ -use super::sha::finalize_sha256_digest; -use crate::msgs::prf as msgs; - -pub mod state { - mod sealed { - pub trait Sealed {} - impl Sealed for super::Ms1 {} - impl Sealed for super::Ms2 {} - impl Sealed for super::Ms3 {} - impl Sealed for super::MsComplete {} - impl Sealed for super::Ke1 {} - impl Sealed for super::Ke2 {} - impl Sealed for super::Ke3 {} - impl Sealed for super::Cf1 {} - impl Sealed for super::Cf2 {} - impl Sealed for super::Sf1 {} - impl Sealed for super::Sf2 {} - } - - pub trait State: sealed::Sealed {} - - pub struct Ms1 {} - pub struct Ms2 { - pub(super) outer_hash_state: [u32; 8], - } - pub struct Ms3 { - pub(super) outer_hash_state: [u32; 8], - } - pub struct MsComplete { - pub(super) p2: [u8; 32], - } - pub struct Ke1 {} - pub struct Ke2 { - pub(super) outer_hash_state: [u32; 8], - } - pub struct Ke3 { - pub(super) outer_hash_state: [u32; 8], - } - pub struct Cf1 { - pub(super) outer_hash_state: [u32; 8], - } - pub struct Cf2 { - pub(super) outer_hash_state: [u32; 8], - } - pub struct Sf1 { - pub(super) outer_hash_state: [u32; 8], - } - pub struct Sf2 { - pub(super) outer_hash_state: [u32; 8], - } - - impl State for Ms1 {} - impl State for Ms2 {} - impl State for Ms3 {} - impl State for MsComplete {} - impl State for Ke1 {} - impl State for Ke2 {} - impl State for Ke3 {} - impl State for Cf1 {} - impl State for Cf2 {} - impl State for Sf1 {} - impl State for Sf2 {} -} - -use state::*; - -pub struct PRFFollower { - state: S, -} - -impl PRFFollower { - /// Returns new PRF follower - pub fn new() -> PRFFollower { - PRFFollower { state: Ms1 {} } - } - - /// Computes a1 - /// ```text - /// H((pms xor opad) || H((pms xor ipad) || seed)) - /// ``` - /// Returns message to [`super::PRFLeader`] and next state - pub fn next( - self, - outer_hash_state: [u32; 8], - msg: msgs::LeaderMs1, - ) -> (msgs::FollowerMs1, PRFFollower) { - ( - msgs::FollowerMs1 { - a1: finalize_sha256_digest(outer_hash_state, 64, &msg.a1_inner_hash), - }, - PRFFollower { - state: Ms2 { outer_hash_state }, - }, - ) - } -} - -impl Default for PRFFollower { - fn default() -> Self { - Self::new() - } -} - -impl PRFFollower { - /// Computes a2 - /// ```text - /// H((pms xor opad) || H((pms xor ipad) || a1)) - /// ``` - /// Returns message to [`super::PRFLeader`] and next state - pub fn next(self, msg: msgs::LeaderMs2) -> (msgs::FollowerMs2, PRFFollower) { - ( - msgs::FollowerMs2 { - a2: finalize_sha256_digest(self.state.outer_hash_state, 64, &msg.a2_inner_hash), - }, - PRFFollower { - state: Ms3 { - outer_hash_state: self.state.outer_hash_state, - }, - }, - ) - } -} - -impl PRFFollower { - /// Computes p2 - /// ```text - /// H((pms xor opad) || H((pms xor ipad) || a2 || seed)) - /// ``` - /// Returns message to [`super::PRFLeader`] and next state - pub fn next(self, msg: msgs::LeaderMs3) -> PRFFollower { - let p2 = finalize_sha256_digest(self.state.outer_hash_state, 64, &msg.p2_inner_hash); - PRFFollower { - state: MsComplete { p2 }, - } - } -} - -impl PRFFollower { - /// Returns master secret p2 - /// ```text - /// H((pms xor opad) || H((pms xor ipad) || a2 || seed)) - /// ``` - pub fn p2(&self) -> [u8; 32] { - self.state.p2 - } - - /// Returns next state - pub fn next(self) -> PRFFollower { - PRFFollower { state: Ke1 {} } - } -} - -impl PRFFollower { - /// Returns next state - pub fn next(self, outer_hash_state: [u32; 8]) -> PRFFollower { - PRFFollower { - state: Ke2 { outer_hash_state }, - } - } -} - -impl PRFFollower { - /// Computes a1 - /// ```text - /// H((ms xor opad) || H((ms xor ipad) || seed)) - /// ``` - /// Returns message to [`super::PRFLeader`] and next state - pub fn next(self, msg: msgs::LeaderKe1) -> (msgs::FollowerKe1, PRFFollower) { - ( - msgs::FollowerKe1 { - a1: finalize_sha256_digest(self.state.outer_hash_state, 64, &msg.a1_inner_hash), - }, - PRFFollower { - state: Ke3 { - outer_hash_state: self.state.outer_hash_state, - }, - }, - ) - } -} - -impl PRFFollower { - /// Computes a2 - /// ```text - /// H((ms xor opad) || H((ms xor ipad) || a1)) - /// ``` - /// Returns message to [`super::PRFLeader`] and next state - pub fn next(self, msg: msgs::LeaderKe2) -> (msgs::FollowerKe2, PRFFollower) { - ( - msgs::FollowerKe2 { - a2: finalize_sha256_digest(self.state.outer_hash_state, 64, &msg.a2_inner_hash), - }, - PRFFollower { - state: Cf1 { - outer_hash_state: self.state.outer_hash_state, - }, - }, - ) - } -} - -impl PRFFollower { - /// Computes a1 - /// ```text - /// H((ms xor opad) || H((ms xor ipad) || cf_seed)) - /// ``` - /// Returns message to [`super::PRFLeader`] and next state - pub fn next(self, msg: msgs::LeaderCf1) -> (msgs::FollowerCf1, PRFFollower) { - ( - msgs::FollowerCf1 { - a1: finalize_sha256_digest(self.state.outer_hash_state, 64, &msg.a1_inner_hash), - }, - PRFFollower { - state: Cf2 { - outer_hash_state: self.state.outer_hash_state, - }, - }, - ) - } -} - -impl PRFFollower { - /// Computes client finished verify_data - /// ```text - /// H((ms xor opad) || H((ms xor ipad) || a1 || cf_seed)) - /// ``` - /// Returns message to [`super::PRFLeader`] and next state - pub fn next(self, msg: msgs::LeaderCf2) -> (msgs::FollowerCf2, PRFFollower) { - let p1 = finalize_sha256_digest(self.state.outer_hash_state, 64, &msg.p1_inner_hash); - let mut verify_data = [0u8; 12]; - verify_data.copy_from_slice(&p1[..12]); - ( - msgs::FollowerCf2 { verify_data }, - PRFFollower { - state: Sf1 { - outer_hash_state: self.state.outer_hash_state, - }, - }, - ) - } -} - -impl PRFFollower { - /// Computes a1 - /// ```text - /// H((ms xor opad) || H((ms xor ipad) || sf_seed)) - /// ``` - /// Returns message to [`super::PRFLeader`] and next state - pub fn next(self, msg: msgs::LeaderSf1) -> (msgs::FollowerSf1, PRFFollower) { - ( - msgs::FollowerSf1 { - a1: finalize_sha256_digest(self.state.outer_hash_state, 64, &msg.a1_inner_hash), - }, - PRFFollower { - state: Sf2 { - outer_hash_state: self.state.outer_hash_state, - }, - }, - ) - } -} - -impl PRFFollower { - /// Computes server finished verify_data - /// ```text - /// H((ms xor opad) || H((ms xor ipad) || a1 || sf_seed)) - /// ``` - /// Returns message to [`super::PRFLeader`] - pub fn next(self, msg: msgs::LeaderSf2) -> msgs::FollowerSf2 { - let p1 = finalize_sha256_digest(self.state.outer_hash_state, 64, &msg.sf_vd_inner_hash); - let mut verify_data = [0u8; 12]; - verify_data.copy_from_slice(&p1[..12]); - msgs::FollowerSf2 { verify_data } - } -} diff --git a/tls/tls-2pc-core/src/prf/leader.rs b/tls/tls-2pc-core/src/prf/leader.rs deleted file mode 100644 index 42341e6af..000000000 --- a/tls/tls-2pc-core/src/prf/leader.rs +++ /dev/null @@ -1,423 +0,0 @@ -use super::{ - sha::finalize_sha256_digest, - utils::{seed_cf, seed_ke, seed_ms, seed_sf}, -}; -use crate::msgs::prf as msgs; - -pub mod state { - mod sealed { - pub trait Sealed {} - impl Sealed for super::Ms1 {} - impl Sealed for super::Ms2 {} - impl Sealed for super::Ms3 {} - impl Sealed for super::MsComplete {} - impl Sealed for super::Ke1 {} - impl Sealed for super::Ke2 {} - impl Sealed for super::Ke3 {} - impl Sealed for super::KeComplete {} - impl Sealed for super::Cf1 {} - impl Sealed for super::Cf2 {} - impl Sealed for super::Cf3 {} - impl Sealed for super::Sf1 {} - impl Sealed for super::Sf2 {} - impl Sealed for super::Sf3 {} - } - - pub trait State: sealed::Sealed {} - - pub struct Ms1 {} - pub struct Ms2 { - pub(super) seed_ms: [u8; 77], - pub(super) inner_hash_state: [u32; 8], - pub(super) client_random: [u8; 32], - pub(super) server_random: [u8; 32], - } - pub struct Ms3 { - pub(super) seed_ms: [u8; 77], - pub(super) inner_hash_state: [u32; 8], - pub(super) p1_inner_hash: [u8; 32], - pub(super) client_random: [u8; 32], - pub(super) server_random: [u8; 32], - } - pub struct MsComplete { - pub(super) p1_inner_hash: [u8; 32], - pub(super) client_random: [u8; 32], - pub(super) server_random: [u8; 32], - } - pub struct Ke1 { - pub(super) client_random: [u8; 32], - pub(super) server_random: [u8; 32], - } - pub struct Ke2 { - pub(super) seed_ke: [u8; 77], - pub(super) inner_hash_state: [u32; 8], - } - pub struct Ke3 { - pub(super) seed_ke: [u8; 77], - pub(super) inner_hash_state: [u32; 8], - pub(super) a1: [u8; 32], - } - pub struct KeComplete { - pub(super) inner_hash_state: [u32; 8], - pub(super) p1_inner_hash: [u8; 32], - pub(super) p2_inner_hash: [u8; 32], - } - pub struct Cf1 { - pub(super) inner_hash_state: [u32; 8], - } - pub struct Cf2 { - pub(super) seed_cf: [u8; 47], - pub(super) inner_hash_state: [u32; 8], - } - pub struct Cf3 { - pub(super) inner_hash_state: [u32; 8], - } - pub struct Sf1 { - pub(super) inner_hash_state: [u32; 8], - } - pub struct Sf2 { - pub(super) inner_hash_state: [u32; 8], - pub(super) seed_sf: [u8; 47], - } - pub struct Sf3 {} - - impl State for Ms1 {} - impl State for Ms2 {} - impl State for Ms3 {} - impl State for MsComplete {} - impl State for Ke1 {} - impl State for Ke2 {} - impl State for Ke3 {} - impl State for KeComplete {} - impl State for Cf1 {} - impl State for Cf2 {} - impl State for Cf3 {} - impl State for Sf1 {} - impl State for Sf2 {} - impl State for Sf3 {} -} - -use state::*; - -pub struct PRFLeader { - state: S, -} - -impl PRFLeader { - /// Creates new PRF leader - pub fn new() -> PRFLeader { - PRFLeader { state: Ms1 {} } - } - - /// Computes a1 inner hash - /// ```text - /// H((pms xor ipad) || seed) - /// ``` - /// Returns message to [`super::PRFFollower`] and next state - pub fn next( - self, - client_random: [u8; 32], - server_random: [u8; 32], - inner_hash_state: [u32; 8], - ) -> (msgs::LeaderMs1, PRFLeader) { - let seed_ms = seed_ms(&client_random, &server_random); - let a1_inner_hash = finalize_sha256_digest(inner_hash_state, 64, &seed_ms); - ( - msgs::LeaderMs1 { a1_inner_hash }, - PRFLeader { - state: Ms2 { - seed_ms, - inner_hash_state, - client_random, - server_random, - }, - }, - ) - } -} - -impl Default for PRFLeader { - fn default() -> Self { - Self::new() - } -} - -impl PRFLeader { - /// Computes p1 and a2 inner hashes - /// ```text - /// a2_inner_hash = H((pms xor ipad) || a1) - /// p1_inner_hash = H((pms xor ipad) || a1 || seed) - /// ``` - /// Returns message to [`super::PRFFollower`] and next state - pub fn next(self, msg: msgs::FollowerMs1) -> (msgs::LeaderMs2, PRFLeader) { - // a1 || seed - let mut a1_seed = [0u8; 109]; - a1_seed[..32].copy_from_slice(&msg.a1); - a1_seed[32..].copy_from_slice(&seed_ms( - &self.state.client_random, - &self.state.server_random, - )); - let p1_inner_hash = finalize_sha256_digest(self.state.inner_hash_state, 64, &a1_seed); - let a2_inner_hash = finalize_sha256_digest(self.state.inner_hash_state, 64, &msg.a1); - ( - msgs::LeaderMs2 { a2_inner_hash }, - PRFLeader { - state: Ms3 { - seed_ms: self.state.seed_ms, - inner_hash_state: self.state.inner_hash_state, - p1_inner_hash, - client_random: self.state.client_random, - server_random: self.state.server_random, - }, - }, - ) - } -} - -impl PRFLeader { - /// Computes p2_inner_hash - /// ```text - /// p2_inner_hash = H((pms xor ipad) || a2 || seed) - /// ``` - /// Returns message to [`super::PRFFollower`] and next state - pub fn next(self, msg: msgs::FollowerMs2) -> (msgs::LeaderMs3, PRFLeader) { - let mut a2_seed = [0u8; 109]; - a2_seed[..32].copy_from_slice(&msg.a2); - a2_seed[32..].copy_from_slice(&self.state.seed_ms); - // p2 inner hash = H((pms xor ipad) || a2 || seed) - let p2_inner_hash = finalize_sha256_digest(self.state.inner_hash_state, 64, &a2_seed); - ( - msgs::LeaderMs3 { p2_inner_hash }, - PRFLeader { - state: MsComplete { - p1_inner_hash: self.state.p1_inner_hash, - client_random: self.state.client_random, - server_random: self.state.server_random, - }, - }, - ) - } -} - -impl PRFLeader { - /// Returns master secret p1 inner hash - /// ```text - /// p1_inner_hash = H((pms xor ipad) || a1 || seed) - /// ``` - pub fn p1_inner_hash(&self) -> [u8; 32] { - self.state.p1_inner_hash - } - - /// Returns next state - pub fn next(self) -> PRFLeader { - PRFLeader { - state: Ke1 { - client_random: self.state.client_random, - server_random: self.state.server_random, - }, - } - } -} - -impl PRFLeader { - /// Computes a1 inner hash - /// ```text - /// a1_inner_hash = H((ms xor ipad) || seed) - /// ``` - /// Returns message to [`super::PRFFollower`] and next state - pub fn next(self, inner_hash_state: [u32; 8]) -> (msgs::LeaderKe1, PRFLeader) { - let seed_ke = seed_ke(&self.state.client_random, &self.state.server_random); - let a1_inner_hash = finalize_sha256_digest(inner_hash_state, 64, &seed_ke); - ( - msgs::LeaderKe1 { a1_inner_hash }, - PRFLeader { - state: Ke2 { - seed_ke, - inner_hash_state, - }, - }, - ) - } -} - -impl PRFLeader { - /// Computes a2_inner_hash - /// ```text - /// a2_inner_hash = H((ms xor ipad) || a1) - /// ``` - /// Returns message to [`super::PRFFollower`] and next state - pub fn next(self, msg: msgs::FollowerKe1) -> (msgs::LeaderKe2, PRFLeader) { - let a2_inner_hash = finalize_sha256_digest(self.state.inner_hash_state, 64, &msg.a1); - ( - msgs::LeaderKe2 { a2_inner_hash }, - PRFLeader { - state: Ke3 { - seed_ke: self.state.seed_ke, - inner_hash_state: self.state.inner_hash_state, - a1: msg.a1, - }, - }, - ) - } -} - -impl PRFLeader { - /// Computes p1 and p2 inner hashes - /// ```text - /// p1_inner_hash = H((ms xor ipad) || a1 || seed) - /// p2_inner_hash = H((ms xor ipad) || a2 || seed) - /// ``` - /// Returns next state - pub fn next(self, msg: msgs::FollowerKe2) -> PRFLeader { - let mut a1_seed = [0u8; 109]; - a1_seed[..32].copy_from_slice(&self.state.a1); - a1_seed[32..].copy_from_slice(&self.state.seed_ke); - - let mut a2_seed = [0u8; 109]; - a2_seed[..32].copy_from_slice(&msg.a2); - a2_seed[32..].copy_from_slice(&self.state.seed_ke); - - // H((ms xor ipad) || a1 || seed) - let p1_inner_hash = finalize_sha256_digest(self.state.inner_hash_state, 64, &a1_seed); - // H((ms xor ipad) || a2 || seed) - let p2_inner_hash = finalize_sha256_digest(self.state.inner_hash_state, 64, &a2_seed); - - PRFLeader { - state: KeComplete { - inner_hash_state: self.state.inner_hash_state, - p1_inner_hash, - p2_inner_hash, - }, - } - } -} - -impl PRFLeader { - /// Returns p1 inner hash from key expansion - /// ```text - /// H((ms xor ipad) || a1 || seed) - /// ``` - pub fn p1_inner_hash(&self) -> [u8; 32] { - self.state.p1_inner_hash - } - - /// Returns p2 inner hash from key expansion - /// ```text - /// H((ms xor ipad) || a2 || seed) - /// ``` - pub fn p2_inner_hash(&self) -> [u8; 32] { - self.state.p2_inner_hash - } - - /// Returns next state - pub fn next(self) -> PRFLeader { - PRFLeader { - state: Cf1 { - inner_hash_state: self.state.inner_hash_state, - }, - } - } -} - -impl PRFLeader { - /// Computes a1 inner hash - /// ```text - /// H((ms xor ipad) || cf_seed) - /// ``` - /// Returns message to [`super::PRFFollower`] and next state - pub fn next(self, handshake_blob: &[u8]) -> (msgs::LeaderCf1, PRFLeader) { - let seed_cf = seed_cf(handshake_blob); - let a1_inner_hash = finalize_sha256_digest(self.state.inner_hash_state, 64, &seed_cf); - ( - msgs::LeaderCf1 { a1_inner_hash }, - PRFLeader { - state: Cf2 { - seed_cf, - inner_hash_state: self.state.inner_hash_state, - }, - }, - ) - } -} - -impl PRFLeader { - /// Computes p1 inner hash - /// ```text - /// H((ms xor ipad) || a1 || cf_seed) - /// ``` - /// Returns message to [`super::PRFFollower`] and next state - pub fn next(self, msg: msgs::FollowerCf1) -> (msgs::LeaderCf2, PRFLeader) { - let mut a1_seed = [0u8; 79]; - a1_seed[..32].copy_from_slice(&msg.a1); - a1_seed[32..].copy_from_slice(&self.state.seed_cf); - let p1_inner_hash = finalize_sha256_digest(self.state.inner_hash_state, 64, &a1_seed); - ( - msgs::LeaderCf2 { p1_inner_hash }, - PRFLeader { - state: Cf3 { - inner_hash_state: self.state.inner_hash_state, - }, - }, - ) - } -} - -impl PRFLeader { - /// Returns client finished verify_data and next state - pub fn next(self, msg: msgs::FollowerCf2) -> ([u8; 12], PRFLeader) { - ( - msg.verify_data, - PRFLeader { - state: Sf1 { - inner_hash_state: self.state.inner_hash_state, - }, - }, - ) - } -} - -impl PRFLeader { - /// Computes a1 inner hash - /// ```text - /// H((ms xor ipad) || sf_seed) - /// ``` - /// Returns message to [`super::PRFFollower`] and next state - pub fn next(self, handshake_blob: &[u8]) -> (msgs::LeaderSf1, PRFLeader) { - let seed_sf = seed_sf(handshake_blob); - let a1_inner_hash = finalize_sha256_digest(self.state.inner_hash_state, 64, &seed_sf); - ( - msgs::LeaderSf1 { a1_inner_hash }, - PRFLeader { - state: Sf2 { - seed_sf, - inner_hash_state: self.state.inner_hash_state, - }, - }, - ) - } -} - -impl PRFLeader { - /// Computes p1 inner hash - /// ```text - /// H((ms xor ipad) || a1 || sf_seed) - /// ``` - /// Returns message to [`super::PRFFollower`] and next state - pub fn next(self, msg: msgs::FollowerSf1) -> (msgs::LeaderSf2, PRFLeader) { - let mut a1_seed = [0u8; 79]; - a1_seed[..32].copy_from_slice(&msg.a1); - a1_seed[32..].copy_from_slice(&self.state.seed_sf); - let sf_vd_inner_hash = finalize_sha256_digest(self.state.inner_hash_state, 64, &a1_seed); - ( - msgs::LeaderSf2 { sf_vd_inner_hash }, - PRFLeader { state: Sf3 {} }, - ) - } -} - -impl PRFLeader { - /// Returns server finished verify_data - pub fn next(self, msg: msgs::FollowerSf2) -> [u8; 12] { - msg.verify_data - } -} diff --git a/tls/tls-2pc-core/src/prf/mod.rs b/tls/tls-2pc-core/src/prf/mod.rs deleted file mode 100644 index 7e1c9de63..000000000 --- a/tls/tls-2pc-core/src/prf/mod.rs +++ /dev/null @@ -1,248 +0,0 @@ -//! This module contains the protocol for computing TLS SHA-256 HMAC PRF using 2PC in such a way -//! that neither party learns the session keys, rather they learn respective XOR shares of the keys. -//! -//! For a more comprehensive explanation of this protocol see our [documentation](https://tlsnotary.github.io/docs-mdbook) -//! -//! To save some compute and bandwidth, the PRF can be broken down into smaller units where some can be -//! computed without using 2PC. -//! -//! To elaborate, recall how HMAC is computed (assuming |k| <= block size): -//! -//! HMAC(k, m) = H((k ⊕ opad) | H((k ⊕ ipad) | m)) -//! -//! Notice that both H(k ⊕ opad) and H(k ⊕ ipad) can be computed separately prior to finalization. In this -//! codebase we name these units as such: -//! - Outer hash state: H(k ⊕ opad) -//! - Inner hash state: H(k ⊕ ipad) -//! - Inner hash: H((k ⊕ ipad) | m) -//! -//! In TLS, the master secret is computed like so: -//! -//! ```text -//! seed = "master secret" | client_random | server_random -//! a0 = seed -//! a1 = HMAC(pms, a0) -//! a2 = HMAC(pms, a1) -//! p1 = HMAC(pms, a1 | seed) -//! p2 = HMAC(pms, a2 | seed) -//! ms = (p1 | p2)[:48] -//! ``` -//! -//! Notice that in each step the key, in this case PMS, is constant. Thus both the outer and inner hash state can be reused -//! for each step. -//! -//! Here is a small illustration of what this looks like: -//! -//! ```text -//! +------------+ +------------+ -//! | | | | -//! | Leader | | Follower | -//! | | | | -//! +-----+------+ +-----+------+ -//! | | -//! | PMS SHARE +-----------+ PMS SHARE | -//! +----------------------> | | <--------------------+ -//! | | 2PC | | -//! | <----------------------+ +--------------------> | -//! | INNER HASH +-----------+ OUTER HASH | -//! | STATE STATE | -//! | H(PMS ⊕ ipad) H(PMS ⊕ opad) | -//! | | -//! -//! H((PMS ⊕ ipad)|seed) ------------------> H((PMS ⊕ opad))|H((PMS ⊕ ipad)|seed))=a1 -//! -//! a1 | -//! <---------------------------------------------------------+ -//! ``` -//! -//! Following, the master secret is expanded to the session keys like so: -//! -//! ```text -//! seed = "key expansion" | server_random | client_random -//! a0 = seed -//! a1 = HMAC(ms, a0) -//! a2 = HMAC(ms, a1) -//! p1 = HMAC(ms, a1 | seed) -//! p2 = HMAC(ms, a2 | seed) -//! ek = (p1 | p2)[:40] -//! cwk = ek[:16] -//! swk = ek[16:32] -//! civ = ek[32:36] -//! siv = ek[36:40] -//! ``` - -mod follower; -mod leader; -pub mod sha; -pub mod utils; - -pub use crate::msgs::prf::PRFMessage; -pub use follower::{state as follower_state, PRFFollower}; -pub use leader::{state as leader_state, PRFLeader}; - -#[cfg(test)] -mod tests { - use self::utils::*; - use super::*; - use hex; - use sha::{finalize_sha256_digest, partial_sha256_digest}; - - #[test] - fn test_prf() { - let client_random = [0x01_u8; 32]; - let server_random = [0x02_u8; 32]; - let pms = [0x03_u8; 32]; - - let (ipad, opad) = generate_hmac_pads(&pms); - - // H(pms xor ipad) - let inner_hash_state = partial_sha256_digest(&ipad); - // H(pms xor opad) - let outer_hash_state = partial_sha256_digest(&opad); - - let leader = PRFLeader::new(); - let follower = PRFFollower::new(); - - let (leader_msg, leader) = leader.next(client_random, server_random, inner_hash_state); - let (follower_msg, follower) = follower.next(outer_hash_state, leader_msg); - - // H((pms xor opad) || H((pms xor ipad) || seed)) - let a1 = follower_msg.a1; - assert_eq!( - &a1, - &hmac_sha256(&pms, &seed_ms(&client_random, &server_random)) - ); - - let (leader_msg, leader) = leader.next(follower_msg); - let (follower_msg, follower) = follower.next(leader_msg); - - // H((pms xor opad) || H((pms xor ipad) || a1)) - let a2 = follower_msg.a2; - assert_eq!(&a2, &hmac_sha256(&pms, &a1)); - - let (leader_msg, leader) = leader.next(follower_msg); - // H((pms xor opad) || H((pms xor ipad) || a2 || seed)) - let follower = follower.next(leader_msg); - - // a1 || seed - let mut a1_seed = [0u8; 109]; - a1_seed[..32].copy_from_slice(&a1); - a1_seed[32..].copy_from_slice(&seed_ms(&client_random, &server_random)); - // H((pms xor opad) || H((pms xor ipad) || a1 || seed)) - let inner_hash = finalize_sha256_digest(inner_hash_state, 64, &a1_seed); - assert_eq!(inner_hash, leader.p1_inner_hash()); - - let leader = leader.next(); - - // a2 || seed - let mut a2_seed = [0u8; 109]; - a2_seed[..32].copy_from_slice(&a2); - a2_seed[32..].copy_from_slice(&seed_ms(&client_random, &server_random)); - let p2 = hmac_sha256(&pms, &a2_seed); - assert_eq!(follower.p2(), p2); - - let follower = follower.next(); - - let p1 = finalize_sha256_digest(outer_hash_state, 64, &inner_hash); - - let mut ms = [0u8; 48]; - ms[..32].copy_from_slice(&p1); - ms[32..48].copy_from_slice(&p2[..16]); - - let (ipad, opad) = generate_hmac_pads(&ms); - - // H(ms xor ipad) - let inner_hash_state = partial_sha256_digest(&ipad); - // H(ms xor opad) - let outer_hash_state = partial_sha256_digest(&opad); - - let (leader_msg, leader) = leader.next(inner_hash_state); - let (follower_msg, follower) = follower.next(outer_hash_state).next(leader_msg); - - // H((ms xor opad) || H((ms xor ipad) || seed)) - let a1 = follower_msg.a1; - assert_eq!( - &a1, - &hmac_sha256(&ms, &seed_ke(&client_random, &server_random)) - ); - - let (leader_msg, leader) = leader.next(follower_msg); - let (follower_msg, follower) = follower.next(leader_msg); - - // H((ms xor opad) || H((ms xor ipad) || a1)) - let a2 = follower_msg.a2; - assert_eq!(&a2, &hmac_sha256(&ms, &a1)); - - let leader = leader.next(follower_msg); - - let p1 = finalize_sha256_digest(outer_hash_state, 64, &leader.p1_inner_hash()); - let p2 = finalize_sha256_digest(outer_hash_state, 64, &leader.p2_inner_hash()); - - let leader = leader.next(); - - let mut ek = [0u8; 40]; - ek[..32].copy_from_slice(&p1); - ek[32..].copy_from_slice(&p2[..8]); - - let handshake_blob = [0x04_u8; 256]; - let (leader_msg, leader) = leader.next(&handshake_blob); - let (follower_msg, follower) = follower.next(leader_msg); - - // H((ms xor opad) || H((ms xor ipad) || seed)) - let a1 = follower_msg.a1; - assert_eq!(&a1, &hmac_sha256(&ms, &seed_cf(&handshake_blob))); - - let (leader_msg, leader) = leader.next(follower_msg); - let (follower_msg, follower) = follower.next(leader_msg); - - // H((ms xor opad) || H((ms xor ipad) || a1 || seed)) - let vd = follower_msg.verify_data; - // a1 || seed - let mut a1_seed = [0u8; 79]; - a1_seed[..32].copy_from_slice(&a1); - a1_seed[32..].copy_from_slice(&seed_cf(&handshake_blob)); - assert_eq!(&vd, &hmac_sha256(&ms, &a1_seed)[..12]); - - let (cfvd, leader) = leader.next(follower_msg); - - let (leader_msg, leader) = leader.next(&handshake_blob); - let (follower_msg, follower) = follower.next(leader_msg); - - // H((ms xor opad) || H((ms xor ipad) || seed)) - let a1 = follower_msg.a1; - assert_eq!(&a1, &hmac_sha256(&ms, &seed_sf(&handshake_blob))); - - let (leader_msg, leader) = leader.next(follower_msg); - let follower_msg = follower.next(leader_msg); - - // H((ms xor opad) || H((ms xor ipad) || a1 || seed)) - let vd = follower_msg.verify_data; - // a1 || seed - let mut a1_seed = [0u8; 79]; - a1_seed[..32].copy_from_slice(&a1); - a1_seed[32..].copy_from_slice(&seed_sf(&handshake_blob)); - assert_eq!(&vd, &hmac_sha256(&ms, &a1_seed)[..12]); - - let sfvd = leader.next(follower_msg); - - // reference values were computed with python3: - // import scapy - // from scapy.layers.tls.crypto import prf - // prffn = prf.PRF() - // cr = bytes([0x01]*32) - // sr = bytes([0x02]*32) - // pms = bytes([0x03]*32) - // handshake_blob = bytes([0x04]*256) - // ms = prffn.compute_master_secret(pms, cr, sr) - // print(prffn.derive_key_block(ms, sr, cr, 40).hex()) - // print(prffn.compute_verify_data("client", "write", handshake_blob, ms).hex()) - // print(prffn.compute_verify_data("server", "write", handshake_blob, ms).hex()) - let reference_ek = - "ede91cf0898c0ac272f1035fe20a8d24d90a6d3bf8be815b4a144cb270e3b8c8e00f2af71471ced8"; - let reference_cfvd = "dc9906a43d25742bc6a479c2"; - let reference_sfvd = "d9f56d1223dea4832a7d8295"; - assert_eq!(hex::encode(ek), reference_ek); - assert_eq!(hex::encode(cfvd), reference_cfvd); - assert_eq!(hex::encode(sfvd), reference_sfvd); - } -} diff --git a/tls/tls-2pc-core/src/prf/sha.rs b/tls/tls-2pc-core/src/prf/sha.rs deleted file mode 100644 index 800a0bbc6..000000000 --- a/tls/tls-2pc-core/src/prf/sha.rs +++ /dev/null @@ -1,81 +0,0 @@ -use core::slice::from_ref; -use digest::{ - block_buffer::{BlockBuffer, Eager}, - generic_array::GenericArray, - typenum::U64, -}; -use sha2::compress256; - -#[allow(dead_code)] -#[inline] -pub fn partial_sha256_digest(input: &[u8]) -> [u32; 8] { - if input.len() % 64 != 0 { - panic!("input length must be a multiple of 64"); - } - let mut state = [ - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, - 0x5be0cd19, - ]; - for b in input.chunks_exact(64) { - let mut block = GenericArray::::default(); - block[..].copy_from_slice(b); - compress256(&mut state, &[block]); - } - state -} - -/// Takes existing state from SHA2 hash and finishes it with additional data -#[inline] -pub fn finalize_sha256_digest(mut state: [u32; 8], pos: usize, input: &[u8]) -> [u8; 32] { - let mut buffer = BlockBuffer::::default(); - buffer.digest_blocks(input, |b| compress256(&mut state, b)); - buffer.digest_pad( - 0x80, - &(((input.len() + pos) * 8) as u64).to_be_bytes(), - |b| compress256(&mut state, from_ref(b)), - ); - - let mut out: [u8; 32] = [0; 32]; - for (chunk, v) in out.chunks_exact_mut(4).zip(state.iter()) { - chunk.copy_from_slice(&v.to_be_bytes()); - } - out -} - -#[cfg(test)] -mod tests { - use super::*; - use sha2::{Digest, Sha256}; - - #[test] - fn test_sha2_initial_state() { - let s = b"test string"; - - // initial state for sha2 - let state = [ - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, - 0x5be0cd19, - ]; - let digest = finalize_sha256_digest(state, 0, s); - - let mut hasher = Sha256::new(); - hasher.update(s); - assert_eq!(digest, hasher.finalize().as_slice()); - } - - #[test] - fn test_sha2_resume_state() { - let s = b"test string test string test string test string test string test"; - - let state = partial_sha256_digest(s); - - let s2 = b"additional data "; - - let digest = finalize_sha256_digest(state, s.len(), s2); - - let mut hasher = Sha256::new(); - hasher.update(s); - hasher.update(s2); - assert_eq!(digest, hasher.finalize().as_slice()); - } -} diff --git a/tls/tls-2pc-core/src/prf/utils.rs b/tls/tls-2pc-core/src/prf/utils.rs deleted file mode 100644 index b5c00b0f9..000000000 --- a/tls/tls-2pc-core/src/prf/utils.rs +++ /dev/null @@ -1,64 +0,0 @@ -#![allow(dead_code)] - -use digest::Digest; -use hmac::{Hmac, Mac}; -use sha2::Sha256; -use std::convert::TryInto; - -type HmacSha256 = Hmac; - -pub fn hmac_sha256(key: &[u8], input: &[u8]) -> [u8; 32] { - let mut mac = HmacSha256::new_from_slice(key).unwrap(); - mac.update(input); - let out = mac.finalize().into_bytes(); - out[..32] - .try_into() - .expect("expected output to be 32 bytes") -} - -pub fn generate_hmac_pads(input: &[u8]) -> ([u8; 64], [u8; 64]) { - let mut ipad = [0x36_u8; 64]; - let mut opad = [0x5c_u8; 64]; - - for (ipad, input) in ipad.iter_mut().zip(input.iter()) { - *ipad ^= *input; - } - for (opad, input) in opad.iter_mut().zip(input.iter()) { - *opad ^= *input; - } - (ipad, opad) -} - -pub fn seed_ms(client_random: &[u8; 32], server_random: &[u8; 32]) -> [u8; 77] { - let mut seed = [0u8; 77]; - seed[..13].copy_from_slice(b"master secret"); - seed[13..45].copy_from_slice(client_random); - seed[45..].copy_from_slice(server_random); - seed -} - -pub fn seed_ke(client_random: &[u8; 32], server_random: &[u8; 32]) -> [u8; 77] { - let mut seed = [0u8; 77]; - seed[..13].copy_from_slice(b"key expansion"); - seed[13..45].copy_from_slice(server_random); - seed[45..].copy_from_slice(client_random); - seed -} - -pub fn seed_cf(handshake_blob: &[u8]) -> [u8; 47] { - let mut hasher = Sha256::new(); - hasher.update(handshake_blob); - let mut seed = [0u8; 47]; - seed[..15].copy_from_slice(b"client finished"); - seed[15..].copy_from_slice(hasher.finalize().as_slice()); - seed -} - -pub fn seed_sf(handshake_blob: &[u8]) -> [u8; 47] { - let mut hasher = Sha256::new(); - hasher.update(handshake_blob); - let mut seed = [0u8; 47]; - seed[..15].copy_from_slice(b"server finished"); - seed[15..].copy_from_slice(hasher.finalize().as_slice()); - seed -} diff --git a/tls/tls-2pc-core/tests/circuit_test.rs b/tls/tls-2pc-core/tests/circuit_test.rs deleted file mode 100644 index fc7f48d7b..000000000 --- a/tls/tls-2pc-core/tests/circuit_test.rs +++ /dev/null @@ -1,666 +0,0 @@ -// Here we test all the c*.bin circuits from ../circuits - -use aes::{ - cipher::{generic_array::GenericArray, BlockEncrypt, NewBlockCipher}, - Aes128, -}; -use hex::FromHex; -use mpc_circuits::Circuit; -use mpc_core::utils::{boolvec_to_u8vec, u8vec_to_boolvec, xor}; -use num::{bigint::RandBigInt, BigUint, Zero}; -use rand::{thread_rng, Rng}; -use tls_2pc_core::{ - handshake::sha, CIRCUIT_1_BYTES, CIRCUIT_2_BYTES, CIRCUIT_3_BYTES, CIRCUIT_4_BYTES, - CIRCUIT_5_BYTES, CIRCUIT_6_BYTES, CIRCUIT_7_BYTES, -}; - -/// NIST P-256 Prime -pub const P: &str = "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff"; - -// Evaluates the circuit "name" with the given inputs (and their sizes in bits) -// and the expected bitsize of outputs. Returns individual outputs as bytes. -fn evaluate_circuit( - circ: &Circuit, - inputs: Vec>, - input_sizes: Vec, - output_sizes: Vec, -) -> Vec> { - // Each circuit's input's bit order is "least bits first". That's why we reverse each input - // individually. (The individual circuit inputs are listed at the top of the c*.casm files) - let mut all_inputs: Vec> = Vec::with_capacity(inputs.len()); - for i in 0..inputs.len() { - // truncate the input to the exact amount of bits if necessary - let mut tmp = - u8vec_to_boolvec(&inputs[i])[(inputs[i].len() * 8 - input_sizes[i])..].to_vec(); - tmp.reverse(); - all_inputs.push(tmp); - } - let inputs: Vec = all_inputs.into_iter().flatten().collect(); - - // same as with inputs, the outputs are "least bit first" and must be reversed individually - let mut output = circ.evaluate(&inputs).unwrap(); - - let mut outputs: Vec> = Vec::with_capacity(output_sizes.len()); - let mut pos: usize = 0; - for i in 0..output_sizes.len() { - let tmp = &mut output[pos..pos + output_sizes[i]]; - tmp.reverse(); - outputs.push(boolvec_to_u8vec(&tmp)); - pos += output_sizes[i]; - } - outputs -} - -// Tests correctness of the c1.casm circuit -fn circuit1(circ: &Circuit, u_share: BigUint, n_share: BigUint) { - // Perform in the clear all the computations which happen inside the ciruit: - let mut rng = thread_rng(); - - let prime = <[u8; 32]>::from_hex(P).unwrap(); - let prime = BigUint::from_bytes_be(&prime); - - // * generate user's and notary's inside-the-GC-masks to mask the GC output - let mask_n: [u8; 32] = rng.gen(); - let mask_u: [u8; 32] = rng.gen(); - - // reduce pms mod prime if necessary - let pms = (u_share.clone() + n_share.clone()) % prime; - - // * XOR pms (zero-padded to 64 bytes) with inner/outer padding of HMAC - let ipad_64x = [0x36u8; 64]; - let opad_64x = [0x5cu8; 64]; - let mut pms_zeropadded = [0u8; 64]; - pms_zeropadded[0..32].copy_from_slice(&pms.to_bytes_be()); - - let mut pms_ipad = [0u8; 64]; - let mut pms_opad = [0u8; 64]; - xor(&ipad_64x, &pms_zeropadded, &mut pms_ipad); - xor(&opad_64x, &pms_zeropadded, &mut pms_opad); - - // * hash the padded PMS - let ohash_state = sha::partial_sha256_digest(&pms_opad); - let ihash_state = sha::partial_sha256_digest(&pms_ipad); - // convert into u8 array - let ohash_state_u8: Vec = ohash_state - .iter() - .map(|u32t| u32t.to_be_bytes()) - .flatten() - .collect(); - let ihash_state_u8: Vec = ihash_state - .iter() - .map(|u32t| u32t.to_be_bytes()) - .flatten() - .collect(); - - // * masked hash state are the expected circuit's outputs - let mut expected1 = [0u8; 32]; - let mut expected2 = [0u8; 32]; - xor(&ohash_state_u8, &mask_n, &mut expected1); - xor(&ihash_state_u8, &mask_u, &mut expected2); - - // Evaluate the circuit. - let outputs = evaluate_circuit( - circ, - vec![ - n_share.to_bytes_be().to_vec(), - mask_n.to_vec(), - u_share.to_bytes_be().to_vec(), - mask_u.to_vec(), - ], - vec![256, 256, 256, 256], - vec![256, 256], - ); - assert_eq!(expected1.to_vec(), outputs[0]); - assert_eq!(expected2.to_vec(), outputs[1]); -} - -#[test] -// Test the circuit's code path when the sum of PMS shares DOES NOT overflow the prime -// and MUST NOT be reduced. -fn circuit1_no_overflow() { - let mut rng = thread_rng(); - let circ = Circuit::load_bytes(CIRCUIT_1_BYTES).unwrap(); - - let prime = <[u8; 32]>::from_hex(P).unwrap(); - let prime = BigUint::from_bytes_be(&prime); - - loop { - // * generate user's and notary's random PMS shares in the field - let n_share = rng.gen_biguint_range(&BigUint::zero(), &prime); - let u_share = rng.gen_biguint_range(&BigUint::zero(), &prime); - if (u_share.clone() + n_share.clone()) < prime { - circuit1(&circ, u_share, n_share); - break; - } - } -} - -#[test] -// Test the circuit's code path when the sum of PMS shares DOES overflow the prime -// and MUST be reduced. -fn circuit1_with_overflow() { - let mut rng = thread_rng(); - let circ = Circuit::load_bytes(CIRCUIT_1_BYTES).unwrap(); - - let prime = <[u8; 32]>::from_hex(P).unwrap(); - let prime = BigUint::from_bytes_be(&prime); - - loop { - // * generate user's and notary's random PMS shares in the field - let n_share = rng.gen_biguint_range(&BigUint::zero(), &prime); - let u_share = rng.gen_biguint_range(&BigUint::zero(), &prime); - if (u_share.clone() + n_share.clone()) >= prime { - circuit1(&circ, u_share, n_share); - break; - } - } -} - -#[test] -// Tests correctness of the c2.casm circuit -fn circuit2() { - // Perform in the clear all the computations which happen inside the ciruit: - let mut rng = thread_rng(); - - let n_outer_hash_state: [u8; 32] = rng.gen(); - let n_output_mask: [u8; 32] = rng.gen(); - let u_inner_hash_p1: [u8; 32] = rng.gen(); - let u_p2: [u8; 16] = rng.gen(); - let u_output_mask: [u8; 32] = rng.gen(); - - // convert outer_hash_state to the expected type [u32; 8] - let mut n_outer_hash_state_u32 = [0u32; 8]; - for i in 0..8 { - let mut tmp = [0u8; 4]; - tmp.copy_from_slice(&n_outer_hash_state[i * 4..(i + 1) * 4]); - n_outer_hash_state_u32[i] = u32::from_be_bytes(tmp); - } - - // finalize the hash to get p1 - let p1 = sha::finalize_sha256_digest(n_outer_hash_state_u32, 64, &u_inner_hash_p1); - // get master_secret - let mut ms = [0u8; 48]; - ms[..32].copy_from_slice(&p1); - ms[32..48].copy_from_slice(&u_p2[..16]); - - // * XOR ms (zero-padded to 64 bytes) with inner/outer padding of HMAC - let ipad_64x = [0x36u8; 64]; - let opad_64x = [0x5cu8; 64]; - let mut ms_zeropadded = [0u8; 64]; - ms_zeropadded[0..48].copy_from_slice(&ms); - - let mut ms_ipad = [0u8; 64]; - let mut ms_opad = [0u8; 64]; - xor(&ipad_64x, &ms_zeropadded, &mut ms_ipad); - xor(&opad_64x, &ms_zeropadded, &mut ms_opad); - - // * hash the padded MS - let ohash_state = sha::partial_sha256_digest(&ms_opad); - let ihash_state = sha::partial_sha256_digest(&ms_ipad); - // convert into u8 array - let ohash_state_u8: Vec = ohash_state - .iter() - .map(|u32t| u32t.to_be_bytes()) - .flatten() - .collect(); - let ihash_state_u8: Vec = ihash_state - .iter() - .map(|u32t| u32t.to_be_bytes()) - .flatten() - .collect(); - - // * masked hash state are the expected circuit's outputs - let mut expected1 = [0u8; 32]; - let mut expected2 = [0u8; 32]; - xor(&ohash_state_u8, &n_output_mask, &mut expected1); - xor(&ihash_state_u8, &u_output_mask, &mut expected2); - - // Evaluate the circuit. - let outputs = evaluate_circuit( - &Circuit::load_bytes(CIRCUIT_2_BYTES).unwrap(), - vec![ - n_outer_hash_state.to_vec(), - n_output_mask.to_vec(), - u_inner_hash_p1.to_vec(), - u_p2.to_vec(), - u_output_mask.to_vec(), - ], - vec![256, 256, 256, 128, 256], - vec![256, 256], - ); - assert_eq!(expected1.to_vec(), outputs[0]); - assert_eq!(expected2.to_vec(), outputs[1]); -} - -#[test] -// Tests correctness of the c3.casm circuit -fn circuit3() { - // Perform in the clear all the computations which happen inside the ciruit: - let mut rng = thread_rng(); - - let n_outer_hash_state: [u8; 32] = rng.gen(); - let n_output_mask1: [u8; 16] = rng.gen(); - let n_output_mask2: [u8; 16] = rng.gen(); - let n_output_mask3: [u8; 4] = rng.gen(); - let n_output_mask4: [u8; 4] = rng.gen(); - let u_inner_hash_p1: [u8; 32] = rng.gen(); - let u_inner_hash_p2: [u8; 32] = rng.gen(); - let u_output_mask1: [u8; 16] = rng.gen(); - let u_output_mask2: [u8; 16] = rng.gen(); - let u_output_mask3: [u8; 4] = rng.gen(); - let u_output_mask4: [u8; 4] = rng.gen(); - - // convert outer_hash_state to the expected type [u32; 8] - let mut n_outer_hash_state_u32 = [0u32; 8]; - for i in 0..8 { - let mut tmp = [0u8; 4]; - tmp.copy_from_slice(&n_outer_hash_state[i * 4..(i + 1) * 4]); - n_outer_hash_state_u32[i] = u32::from_be_bytes(tmp); - } - - // finalize the hash to get p1 - let p1 = sha::finalize_sha256_digest(n_outer_hash_state_u32, 64, &u_inner_hash_p1); - // finalize the hash to get p2 - let p2 = sha::finalize_sha256_digest(n_outer_hash_state_u32, 64, &u_inner_hash_p2); - - // get expanded_keys (TLS session keys) - let mut ek = [0u8; 40]; - ek[..32].copy_from_slice(&p1); - ek[32..40].copy_from_slice(&p2[0..8]); - // split into client/server_write_key and client/server_write_iv - let mut cwk = [0u8; 16]; - cwk.copy_from_slice(&ek[0..16]); - let mut swk = [0u8; 16]; - swk.copy_from_slice(&ek[16..32]); - let mut civ = [0u8; 4]; - civ.copy_from_slice(&ek[32..36]); - let mut siv = [0u8; 4]; - siv.copy_from_slice(&ek[36..40]); - - // XOR each keys with Notary's mask and then with User's mask - let mut swk_masked = [0u8; 16]; - xor(&swk, &n_output_mask1, &mut swk_masked); - xor(&swk_masked.clone(), &u_output_mask1, &mut swk_masked); - let mut cwk_masked = [0u8; 16]; - xor(&cwk, &n_output_mask2, &mut cwk_masked); - xor(&cwk_masked.clone(), &u_output_mask2, &mut cwk_masked); - let mut siv_masked = [0u8; 4]; - xor(&siv, &n_output_mask3, &mut siv_masked); - xor(&siv_masked.clone(), &u_output_mask3, &mut siv_masked); - let mut civ_masked = [0u8; 4]; - xor(&civ, &n_output_mask4, &mut civ_masked); - xor(&civ_masked.clone(), &u_output_mask4, &mut civ_masked); - - // Evaluate the circuit. - let outputs = evaluate_circuit( - &Circuit::load_bytes(CIRCUIT_3_BYTES).unwrap(), - vec![ - n_outer_hash_state.to_vec(), - n_output_mask1.to_vec(), - n_output_mask2.to_vec(), - n_output_mask3.to_vec(), - n_output_mask4.to_vec(), - u_inner_hash_p1.to_vec(), - u_inner_hash_p2.to_vec(), - u_output_mask1.to_vec(), - u_output_mask2.to_vec(), - u_output_mask3.to_vec(), - u_output_mask4.to_vec(), - ], - vec![256, 128, 128, 32, 32, 256, 256, 128, 128, 32, 32], - vec![128, 128, 32, 32], - ); - assert_eq!(swk_masked.to_vec(), outputs[0]); - assert_eq!(cwk_masked.to_vec(), outputs[1]); - assert_eq!(siv_masked.to_vec(), outputs[2]); - assert_eq!(civ_masked.to_vec(), outputs[3]); -} - -#[test] -// Tests correctness of the c4.casm circuit -fn circuit4() { - // Perform in the clear all the computations which happen inside the ciruit: - let mut rng = thread_rng(); - - let n_swk: [u8; 16] = rng.gen(); - let n_cwk: [u8; 16] = rng.gen(); - let n_siv: [u8; 4] = rng.gen(); - let n_civ: [u8; 4] = rng.gen(); - let n_output_mask5: [u8; 16] = rng.gen(); - let n_output_mask6: [u8; 16] = rng.gen(); - let u_swk: [u8; 16] = rng.gen(); - let u_cwk: [u8; 16] = rng.gen(); - let u_siv: [u8; 4] = rng.gen(); - let u_civ: [u8; 4] = rng.gen(); - let u_output_mask5: [u8; 16] = rng.gen(); - let u_output_mask6: [u8; 16] = rng.gen(); - let u_output_mask7: [u8; 16] = rng.gen(); - - // combine key shares - let mut swk = [0u8; 16]; - xor(&n_swk, &u_swk, &mut swk); - let mut cwk = [0u8; 16]; - xor(&n_cwk, &u_cwk, &mut cwk); - let mut siv = [0u8; 4]; - xor(&n_siv, &u_siv, &mut siv); - let mut civ = [0u8; 4]; - xor(&n_civ, &u_civ, &mut civ); - - // set AES key - let key = GenericArray::clone_from_slice(&cwk); - let cipher = Aes128::new(&key); - - // AES-ECB encrypt 0, get MAC key - let mut z = GenericArray::clone_from_slice(&[0u8; 16]); - cipher.encrypt_block(&mut z); - let mac_key = z; - - // AES-ECB encrypt a block with counter==1 and nonce==1, get GCTR block - let nonce: [u8; 8] = 1u64.to_be_bytes(); - let counter: [u8; 4] = 1u32.to_be_bytes(); - let mut msg = [0u8; 16]; - msg[0..4].copy_from_slice(&civ); - msg[4..12].copy_from_slice(&nonce); - msg[12..16].copy_from_slice(&counter); - let mut msg = GenericArray::clone_from_slice(&msg); - cipher.encrypt_block(&mut msg); - let gctr_block = msg; - - // AES-ECB encrypt a block with counter==2 and nonce==1 - let nonce: [u8; 8] = 1u64.to_be_bytes(); - let counter: [u8; 4] = 2u32.to_be_bytes(); - let mut msg = [0u8; 16]; - msg[0..4].copy_from_slice(&civ); - msg[4..12].copy_from_slice(&nonce); - msg[12..16].copy_from_slice(&counter); - let mut msg = GenericArray::clone_from_slice(&msg); - cipher.encrypt_block(&mut msg); - let first_block = msg; - - // XOR MAC key and GCTR block with Notary's mask and then with User's mask - let mut mac_key_masked = [0u8; 16]; - xor(&mac_key, &n_output_mask5, &mut mac_key_masked); - xor( - &mac_key_masked.clone(), - &u_output_mask5, - &mut mac_key_masked, - ); - let mut gctr_block_masked = [0u8; 16]; - xor(&gctr_block, &n_output_mask6, &mut gctr_block_masked); - xor( - &gctr_block_masked.clone(), - &u_output_mask6, - &mut gctr_block_masked, - ); - - // XOR the first block with User's mask - let mut first_block_masked = [0u8; 16]; - xor(&first_block, &u_output_mask7, &mut first_block_masked); - - // Evaluate the circuit. - let outputs = evaluate_circuit( - &Circuit::load_bytes(CIRCUIT_4_BYTES).unwrap(), - vec![ - n_swk.to_vec(), - n_cwk.to_vec(), - n_siv.to_vec(), - n_civ.to_vec(), - n_output_mask5.to_vec(), - n_output_mask6.to_vec(), - u_swk.to_vec(), - u_cwk.to_vec(), - u_siv.to_vec(), - u_civ.to_vec(), - u_output_mask5.to_vec(), - u_output_mask6.to_vec(), - u_output_mask7.to_vec(), - ], - vec![128, 128, 32, 32, 128, 128, 128, 128, 32, 32, 128, 128, 128], - vec![128, 128, 128], - ); - assert_eq!(mac_key_masked.to_vec(), outputs[0]); - assert_eq!(gctr_block_masked.to_vec(), outputs[1]); - assert_eq!(first_block_masked.to_vec(), outputs[2]); -} - -#[test] -// Tests correctness of the c5.casm circuit -fn circuit5() { - // Perform in the clear all the computations which happen inside the ciruit: - let mut rng = thread_rng(); - - let n_outer_hash_state_p1: [u8; 32] = rng.gen(); - let n_swk: [u8; 16] = rng.gen(); - let n_siv: [u8; 4] = rng.gen(); - let n_output_mask1: [u8; 16] = rng.gen(); - let n_output_mask2: [u8; 16] = rng.gen(); - let u_inner_hash_state_p1: [u8; 32] = rng.gen(); - let u_swk: [u8; 16] = rng.gen(); - let u_siv: [u8; 4] = rng.gen(); - let u_server_finished_nonce: [u8; 8] = rng.gen(); - let u_output_mask1: [u8; 16] = rng.gen(); - let u_output_mask2: [u8; 16] = rng.gen(); - let u_output_mask3: [u8; 16] = rng.gen(); - let u_output_mask4: [u8; 12] = rng.gen(); - - // convert outer_hash_state to the expected type [u32; 8] - let mut n_outer_hash_state_p1_u32 = [0u32; 8]; - for i in 0..8 { - let mut tmp = [0u8; 4]; - tmp.copy_from_slice(&n_outer_hash_state_p1[i * 4..(i + 1) * 4]); - n_outer_hash_state_p1_u32[i] = u32::from_be_bytes(tmp); - } - - // finalize the hash to get p1 - let p1 = sha::finalize_sha256_digest(n_outer_hash_state_p1_u32, 64, &u_inner_hash_state_p1); - let mut verify_data = [0u8; 12]; - verify_data.copy_from_slice(&p1[0..12]); - - // combine key shares - let mut swk = [0u8; 16]; - xor(&n_swk, &u_swk, &mut swk); - let mut siv = [0u8; 4]; - xor(&n_siv, &u_siv, &mut siv); - - // set AES key - let key = GenericArray::clone_from_slice(&swk); - let cipher = Aes128::new(&key); - - // AES-ECB encrypt 0, get MAC key - let mut z = GenericArray::clone_from_slice(&[0u8; 16]); - cipher.encrypt_block(&mut z); - let mac_key = z; - - // AES-ECB encrypt a block with counter==1 and nonce from Server_Finished, get GCTR block - let counter: [u8; 4] = 1u32.to_be_bytes(); - let mut msg = [0u8; 16]; - msg[0..4].copy_from_slice(&siv); - msg[4..12].copy_from_slice(&u_server_finished_nonce); - msg[12..16].copy_from_slice(&counter); - let mut msg = GenericArray::clone_from_slice(&msg); - cipher.encrypt_block(&mut msg); - let gctr_block = msg; - - // AES-ECB encrypt a block with counter==2 and nonce from Server_Finished - let counter: [u8; 4] = 2u32.to_be_bytes(); - let mut msg = [0u8; 16]; - msg[0..4].copy_from_slice(&siv); - msg[4..12].copy_from_slice(&u_server_finished_nonce); - msg[12..16].copy_from_slice(&counter); - let mut msg = GenericArray::clone_from_slice(&msg); - cipher.encrypt_block(&mut msg); - let first_block = msg; - - // XOR MAC key and GCTR block with Notary's mask and then with User's mask - let mut mac_key_masked = [0u8; 16]; - xor(&mac_key, &n_output_mask1, &mut mac_key_masked); - xor( - &mac_key_masked.clone(), - &u_output_mask1, - &mut mac_key_masked, - ); - let mut gctr_block_masked = [0u8; 16]; - xor(&gctr_block, &n_output_mask2, &mut gctr_block_masked); - xor( - &gctr_block_masked.clone(), - &u_output_mask2, - &mut gctr_block_masked, - ); - - // XOR the first block and verify_data with User's mask - let mut first_block_masked = [0u8; 16]; - xor(&first_block, &u_output_mask3, &mut first_block_masked); - let mut verify_data_masked = [0u8; 12]; - xor(&verify_data, &u_output_mask4, &mut verify_data_masked); - - // Evaluate the circuit. - let outputs = evaluate_circuit( - &Circuit::load_bytes(CIRCUIT_5_BYTES).unwrap(), - vec![ - n_outer_hash_state_p1.to_vec(), - n_swk.to_vec(), - n_siv.to_vec(), - n_output_mask1.to_vec(), - n_output_mask2.to_vec(), - u_inner_hash_state_p1.to_vec(), - u_swk.to_vec(), - u_siv.to_vec(), - u_server_finished_nonce.to_vec(), - u_output_mask1.to_vec(), - u_output_mask2.to_vec(), - u_output_mask3.to_vec(), - u_output_mask4.to_vec(), - ], - vec![256, 128, 32, 128, 128, 256, 128, 32, 64, 128, 128, 128, 96], - vec![128, 128, 128, 96], - ); - assert_eq!(mac_key_masked.to_vec(), outputs[0]); - assert_eq!(gctr_block_masked.to_vec(), outputs[1]); - assert_eq!(first_block_masked.to_vec(), outputs[2]); - assert_eq!(verify_data_masked.to_vec(), outputs[3]); -} - -#[test] -// Tests correctness of the c6.casm circuit -fn circuit6() { - // Perform in the clear all the computations which happen inside the ciruit: - let mut rng = thread_rng(); - - let n_cwk: [u8; 16] = rng.gen(); - let n_civ: [u8; 4] = rng.gen(); - let u_cwk: [u8; 16] = rng.gen(); - let u_civ: [u8; 4] = rng.gen(); - let u_output_mask: [u8; 16] = rng.gen(); - let u_nonce: [u8; 2] = rng.gen(); - let u_counter: [u8; 2] = rng.gen(); - - // combine key shares - let mut cwk = [0u8; 16]; - xor(&n_cwk, &u_cwk, &mut cwk); - let mut civ = [0u8; 4]; - xor(&n_civ, &u_civ, &mut civ); - - // set AES key - let key = GenericArray::clone_from_slice(&cwk); - let cipher = Aes128::new(&key); - - // AES-ECB encrypt a block with counter and nonce from the User's input - let mut msg = [0u8; 16]; - msg[0..4].copy_from_slice(&civ); - // 54 msb of nonce must be zero - let mut nonce_bool = vec![false; 64]; - nonce_bool[54..64].copy_from_slice(&u8vec_to_boolvec(&u_nonce)[6..16]); - msg[4..12].copy_from_slice(&boolvec_to_u8vec(&nonce_bool)); - // 22 msb of counter must be zero - let mut counter_bool = vec![false; 32]; - counter_bool[22..32].copy_from_slice(&u8vec_to_boolvec(&u_counter)[6..16]); - msg[12..16].copy_from_slice(&boolvec_to_u8vec(&counter_bool)); - let mut msg = GenericArray::clone_from_slice(&msg); - cipher.encrypt_block(&mut msg); - let encr_block = msg; - - // XOR-mask the encrypted block with User's mask - let mut encr_block_masked = [0u8; 16]; - xor(&encr_block, &u_output_mask, &mut encr_block_masked); - - // Evaluate the circuit. - let outputs = evaluate_circuit( - &Circuit::load_bytes(CIRCUIT_6_BYTES).unwrap(), - vec![ - n_cwk.to_vec(), - n_civ.to_vec(), - u_cwk.to_vec(), - u_civ.to_vec(), - u_output_mask.to_vec(), - u_nonce.to_vec(), - u_counter.to_vec(), - ], - vec![128, 32, 128, 32, 128, 10, 10], - vec![128], - ); - assert_eq!(encr_block_masked.to_vec(), outputs[0]); -} - -#[test] -// Tests correctness of the c7.casm circuit -fn circuit7() { - // Perform in the clear all the computations which happen inside the ciruit: - let mut rng = thread_rng(); - - let n_cwk: [u8; 16] = rng.gen(); - let n_civ: [u8; 4] = rng.gen(); - let n_output_mask: [u8; 16] = rng.gen(); - let u_cwk: [u8; 16] = rng.gen(); - let u_civ: [u8; 4] = rng.gen(); - let u_output_mask: [u8; 16] = rng.gen(); - let u_nonce: [u8; 2] = rng.gen(); - - // combine key shares - let mut cwk = [0u8; 16]; - xor(&n_cwk, &u_cwk, &mut cwk); - let mut civ = [0u8; 4]; - xor(&n_civ, &u_civ, &mut civ); - - // set AES key - let key = GenericArray::clone_from_slice(&cwk); - let cipher = Aes128::new(&key); - - // AES-ECB encrypt a block with counter==1 and nonce from the User's input - let mut msg = [0u8; 16]; - msg[0..4].copy_from_slice(&civ); - // 48 msb of nonce must be zero - let mut nonce_bool = vec![false; 64]; - nonce_bool[48..64].copy_from_slice(&u8vec_to_boolvec(&u_nonce)[0..16]); - msg[4..12].copy_from_slice(&boolvec_to_u8vec(&nonce_bool)); - let counter: [u8; 4] = 1u32.to_be_bytes(); - msg[12..16].copy_from_slice(&counter); - let mut msg = GenericArray::clone_from_slice(&msg); - cipher.encrypt_block(&mut msg); - let gctr_block = msg; - - // XOR-mask the encrypted block with the Notary's mask and with the Client's mask - let mut gctr_block_masked = [0u8; 16]; - xor(&gctr_block, &u_output_mask, &mut gctr_block_masked); - xor( - &gctr_block_masked.clone(), - &n_output_mask, - &mut gctr_block_masked, - ); - - // Evaluate the circuit. - let outputs = evaluate_circuit( - &Circuit::load_bytes(CIRCUIT_7_BYTES).unwrap(), - vec![ - n_cwk.to_vec(), - n_civ.to_vec(), - n_output_mask.to_vec(), - u_cwk.to_vec(), - u_civ.to_vec(), - u_output_mask.to_vec(), - u_nonce.to_vec(), - ], - vec![128, 32, 128, 128, 32, 128, 16], - vec![128], - ); - assert_eq!(gctr_block_masked.to_vec(), outputs[0]); -} diff --git a/tls/tls-circuits/Cargo.toml b/tls/tls-circuits/Cargo.toml deleted file mode 100644 index 7c65d145e..000000000 --- a/tls/tls-circuits/Cargo.toml +++ /dev/null @@ -1,24 +0,0 @@ -[package] -name = "tlsn-tls-circuits" -version = "0.1.0" -edition = "2021" - -[lib] -name = "tls_circuits" - -[features] -default = ["compile"] -compile = ["dep:rayon", "dep:prost"] - -[dependencies] -tlsn-mpc-circuits.workspace = true -rayon = { workspace = true, optional = true } -prost = { workspace = true, optional = true } - -[dev-dependencies] -generic-array.workspace = true -num-bigint = { workspace = true, features = ["rand"] } -num-traits.workspace = true -rand.workspace = true -sha2 = { workspace = true, features = ["compress"] } -aes.workspace = true diff --git a/tls/tls-circuits/src/c1.rs b/tls/tls-circuits/src/c1.rs deleted file mode 100644 index 1bfd53802..000000000 --- a/tls/tls-circuits/src/c1.rs +++ /dev/null @@ -1,302 +0,0 @@ -use std::sync::Arc; - -use crate::{combine_pms_shares, SHA256_STATE}; -use mpc_circuits::{ - builder::{map_le_bytes, CircuitBuilder}, - circuits::nbit_xor, - Circuit, ValueType, SHA_256, -}; - -/// TLS stage 1 -/// -/// Parties input their additive shares of the pre-master secret (PMS). -/// Outputs sha256(pms xor opad) called "pms outer hash state" to Notary and -/// also outputs sha256(pms xor ipad) called "pms inner hash state" to User. -/// -/// Inputs: -/// -/// 0. PMS_SHARE_A: 32-byte PMS Additive Share -/// 1. PMS_SHARE_B: 32-byte PMS Additive Share -/// 2. MASK_I: 32-byte mask for inner-state -/// 3. MASK_O: 32-byte mask for outer-state -/// -/// Outputs: -/// -/// 0. MASKED_I: 32-byte masked HMAC inner hash state -/// 1. MASKED_O: 32-byte masked HMAC outer hash state -pub fn c1() -> Arc { - let mut builder = CircuitBuilder::new("c1", "", "0.1.0"); - - let share_a = builder.add_input( - "PMS_SHARE_A", - "32-byte PMS Additive Share", - ValueType::Bytes, - 256, - ); - let share_b = builder.add_input( - "PMS_SHARE_B", - "32-byte PMS Additive Share", - ValueType::Bytes, - 256, - ); - let mask_inner = builder.add_input( - "MASK_I", - "32-byte mask for inner-state", - ValueType::Bytes, - 256, - ); - let mask_outer = builder.add_input( - "MASK_O", - "32-byte mask for outer-state", - ValueType::Bytes, - 256, - ); - let const_zero = builder.add_input( - "const_zero", - "input that is always 0", - ValueType::ConstZero, - 1, - ); - let const_one = builder.add_input( - "const_one", - "input that is always 1", - ValueType::ConstOne, - 1, - ); - - let mut builder = builder.build_inputs(); - - let sha256 = Circuit::load_bytes(SHA_256).expect("failed to load sha256 circuit"); - let xor_512_circ = nbit_xor(512); - let xor_256_circ = nbit_xor(256); - - let combine_pms = builder.add_circ(&combine_pms_shares()); - let sha256_ipad = builder.add_circ(&sha256); - let sha256_opad = builder.add_circ(&sha256); - let pms_ipad = builder.add_circ(&xor_512_circ); - let pms_opad = builder.add_circ(&xor_512_circ); - let masked_inner = builder.add_circ(&xor_256_circ); - let masked_outer = builder.add_circ(&xor_256_circ); - - builder.connect( - &share_a[..], - &combine_pms - .input(0) - .expect("combine_pms_shares missing input 0")[..], - ); - builder.connect( - &share_b[..], - &combine_pms - .input(1) - .expect("combine_pms_shares missing input 0")[..], - ); - builder.connect( - &const_zero[..], - &combine_pms - .input(2) - .expect("combine_pms_shares missing input 2")[..], - ); - builder.connect( - &const_one[..], - &combine_pms - .input(3) - .expect("combine_pms_shares missing input 3")[..], - ); - - let pms = combine_pms - .output(0) - .expect("combine_pms_shares missing output 0"); - - // inner - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &pms_ipad.input(0).expect("nbit_xor missing input 0")[..], - &[0x36u8; 64], - ); - builder.connect( - &pms[..], - &pms_ipad.input(1).expect("nbit_xor missing input 1")[256..], - ); - builder.connect( - &[const_zero[0]; 256], - &pms_ipad.input(1).expect("nbit_xor missing input 1")[..256], - ); - builder.connect( - &pms_ipad.output(0).expect("nbit_xor missing output 0")[..], - &sha256_ipad.input(0).expect("sha256 missing input 0")[..], - ); - // map SHA256 initial state - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &sha256_ipad.input(1).expect("sha256 missing input 1")[..], - &SHA256_STATE - .iter() - .rev() - .map(|chunk| chunk.to_le_bytes()) - .flatten() - .collect::>(), - ); - - // outer - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &pms_opad.input(0).expect("nbit_xor missing input 0")[..], - &[0x5cu8; 64], - ); - builder.connect( - &pms[..], - &pms_opad.input(1).expect("nbit_xor missing input 1")[256..], - ); - builder.connect( - &[const_zero[0]; 256], - &pms_opad.input(1).expect("nbit_xor missing input 1")[..256], - ); - builder.connect( - &pms_opad.output(0).expect("nbit_xor missing output 0")[..], - &sha256_opad.input(0).expect("sha256 missing input 0")[..], - ); - // map SHA256 initial state - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &sha256_opad.input(1).expect("sha256 missing input 1")[..], - &SHA256_STATE - .iter() - .rev() - .map(|chunk| chunk.to_le_bytes()) - .flatten() - .collect::>(), - ); - - // mask inner - builder.connect( - &sha256_ipad.output(0).expect("sha256 missing output 0")[..], - &masked_inner.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &mask_inner[..], - &masked_inner.input(1).expect("nbit_xor missing input 1")[..], - ); - - // mask outer - builder.connect( - &sha256_opad.output(0).expect("sha256 missing output 0")[..], - &masked_outer.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &mask_outer[..], - &masked_outer.input(1).expect("nbit_xor missing input 1")[..], - ); - - let mut builder = builder.build_gates(); - - let out_inner = builder.add_output( - "MASKED_I", - "32-byte masked HMAC inner hash state", - ValueType::Bytes, - 256, - ); - let out_outer = builder.add_output( - "MASKED_O", - "32-byte masked HMAC outer hash state", - ValueType::Bytes, - 256, - ); - - builder.connect( - &masked_inner.output(0).expect("nbit_xor missing output 0")[..], - &out_inner[..], - ); - builder.connect( - &masked_outer.output(0).expect("nbit_xor missing output 0")[..], - &out_outer[..], - ); - - builder.build_circuit().expect("failed to build c1") -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::test_helpers::{partial_sha256_digest, test_circ}; - use mpc_circuits::Value; - use num_bigint::{BigUint, RandBigInt}; - use rand::{thread_rng, Rng}; - - /// NIST P-256 Prime - pub const P: &str = "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff"; - - #[test] - #[ignore = "expensive"] - fn test_c1() { - let circ = c1(); - // Perform in the clear all the computations which happen inside the ciruit: - let mut rng = thread_rng(); - - let p = BigUint::parse_bytes(P.as_bytes(), 16).unwrap(); - let share_a = rng.gen_biguint_below(&p); - let share_b = rng.gen_biguint_below(&p); - - // * generate user's and notary's inside-the-GC-masks to mask the GC output - let mask_n: [u8; 32] = rng.gen(); - let mask_u: [u8; 32] = rng.gen(); - - // reduce pms mod prime if necessary - let pms = (share_a.clone() + share_b.clone()) % p; - - // * XOR pms (zero-padded to 64 bytes) with inner/outer padding of HMAC - let mut pms_zeropadded = [0u8; 64]; - pms_zeropadded[0..32].copy_from_slice(&pms.to_bytes_be()); - - let pms_ipad = pms_zeropadded.iter().map(|b| b ^ 0x36).collect::>(); - let pms_opad = pms_zeropadded.iter().map(|b| b ^ 0x5c).collect::>(); - - // * hash the padded PMS - let ohash_state = partial_sha256_digest(&pms_opad); - let ihash_state = partial_sha256_digest(&pms_ipad); - // convert into u8 array - let ohash_state_u8: Vec = ohash_state - .iter() - .map(|u32t| u32t.to_be_bytes()) - .flatten() - .collect(); - let ihash_state_u8: Vec = ihash_state - .iter() - .map(|u32t| u32t.to_be_bytes()) - .flatten() - .collect(); - - // * masked hash state are the expected circuit's outputs - let expected_inner = ihash_state_u8 - .into_iter() - .zip(mask_u) - .map(|(b, mask)| b ^ mask) - .collect::>(); - let expected_outer = ohash_state_u8 - .into_iter() - .zip(mask_n) - .map(|(b, mask)| b ^ mask) - .collect::>(); - - test_circ( - &circ, - &[ - Value::Bytes(share_a.to_bytes_le().to_vec()), - Value::Bytes(share_b.to_bytes_le().to_vec()), - Value::Bytes(mask_u.iter().rev().copied().collect::>()), - Value::Bytes(mask_n.iter().rev().copied().collect::>()), - ], - &[ - Value::Bytes(expected_inner.into_iter().rev().collect()), - Value::Bytes(expected_outer.into_iter().rev().collect()), - ], - ); - } -} diff --git a/tls/tls-circuits/src/c2.rs b/tls/tls-circuits/src/c2.rs deleted file mode 100644 index 6e35d7ade..000000000 --- a/tls/tls-circuits/src/c2.rs +++ /dev/null @@ -1,304 +0,0 @@ -use std::sync::Arc; - -use crate::SHA256_STATE; -use mpc_circuits::{ - builder::{map_le_bytes, CircuitBuilder}, - circuits::nbit_xor, - Circuit, ValueType, SHA_256, -}; - -/// TLS stage 2 -/// -/// Computes the master secret (MS). -/// Outputs sha256(ms xor opad) called "ms outer hash state" and -/// sha256(ms xor ipad) called "ms inner hash state" -/// -/// Inputs: -/// -/// 0. PMS_O_STATE: 32-byte PMS outer-hash state -/// 1. P1_INNER: 32-byte inner hash of P1 -/// 2. P2: 16-byte P2 -/// 3. MASK_I: 32-byte mask for inner-state -/// 4. MASK_O: 32-byte mask for outer-state -/// -/// Outputs: -/// -/// 0. MASKED_I: 32-byte masked HMAC inner hash state -/// 1. MASKED_O: 32-byte masked HMAC outer hash state -pub fn c2() -> Arc { - let mut builder = CircuitBuilder::new("c2", "", "0.1.0"); - - let pms_o = builder.add_input("PMS_O_STATE", "32-byte hash state", ValueType::Bytes, 256); - let p1_inner = builder.add_input("P1_INNER", "32-byte hash state", ValueType::Bytes, 256); - let p2 = builder.add_input("P2", "16-byte P2", ValueType::Bytes, 128); - let mask_inner = builder.add_input( - "MASK_I", - "32-byte mask for inner-state", - ValueType::Bytes, - 256, - ); - let mask_outer = builder.add_input( - "MASK_O", - "32-byte mask for outer-state", - ValueType::Bytes, - 256, - ); - let const_zero = builder.add_input( - "const_zero", - "input that is always 0", - ValueType::ConstZero, - 1, - ); - let const_one = builder.add_input( - "const_one", - "input that is always 1", - ValueType::ConstOne, - 1, - ); - - let mut builder = builder.build_inputs(); - - let sha256 = Circuit::load_bytes(SHA_256).expect("failed to load sha256 circuit"); - let xor_512_circ = nbit_xor(512); - let xor_256_circ = nbit_xor(256); - - let sha256_p1 = builder.add_circ(&sha256); - let sha256_ipad = builder.add_circ(&sha256); - let sha256_opad = builder.add_circ(&sha256); - let ms_ipad = builder.add_circ(&xor_512_circ); - let ms_opad = builder.add_circ(&xor_512_circ); - let masked_inner = builder.add_circ(&xor_256_circ); - let masked_outer = builder.add_circ(&xor_256_circ); - - // p1 - let sha256_p1_msg = sha256_p1.input(0).expect("sha256 missing input 1"); - builder.connect(&p1_inner[..], &sha256_p1_msg[256..]); - // append a single '1' bit - builder.connect(&[const_one[0]], &[sha256_p1_msg[255]]); - // append K '0' bits, where K is the minimum number >= 0 such that (L + 1 + K + 64) is a multiple of 512 - builder.connect(&[const_zero[0]; 239], &sha256_p1_msg[16..255]); - // append L as a 64-bit big-endian integer, making the total post-processed length a multiple of 512 bits - // L = 768 = 0x0300 - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &sha256_p1_msg[..16], - &[0x00, 0x03], - ); - builder.connect( - &pms_o[..], - &sha256_p1.input(1).expect("sha256 missing input 1")[..], - ); - - let p1 = sha256_p1.output(0).expect("sha256 missing output 0"); - - // inner - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &ms_ipad.input(0).expect("nbit_xor missing input 0")[..], - &[0x36u8; 64], - ); - builder.connect( - &p1[..], - &ms_ipad.input(1).expect("nbit_xor missing input 1")[256..], - ); - builder.connect( - &p2[..], - &ms_ipad.input(1).expect("nbit_xor missing input 1")[128..256], - ); - builder.connect( - &[const_zero[0]; 128], - &ms_ipad.input(1).expect("nbit_xor missing input 1")[..128], - ); - builder.connect( - &ms_ipad.output(0).expect("nbit_xor missing output 0")[..], - &sha256_ipad.input(0).expect("sha256 missing input 0")[..], - ); - // map SHA256 initial state - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &sha256_ipad.input(1).expect("sha256 missing input 1")[..], - &SHA256_STATE - .iter() - .rev() - .map(|chunk| chunk.to_le_bytes()) - .flatten() - .collect::>(), - ); - - // outer - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &ms_opad.input(0).expect("nbit_xor missing input 0")[..], - &[0x5cu8; 64], - ); - builder.connect( - &p1[..], - &ms_opad.input(1).expect("nbit_xor missing input 1")[256..], - ); - builder.connect( - &p2[..], - &ms_opad.input(1).expect("nbit_xor missing input 1")[128..256], - ); - builder.connect( - &[const_zero[0]; 128], - &ms_opad.input(1).expect("nbit_xor missing input 1")[..128], - ); - builder.connect( - &ms_opad.output(0).expect("nbit_xor missing output 0")[..], - &sha256_opad.input(0).expect("sha256 missing input 0")[..], - ); - // map SHA256 initial state - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &sha256_opad.input(1).expect("sha256 missing input 1")[..], - &SHA256_STATE - .iter() - .rev() - .map(|chunk| chunk.to_le_bytes()) - .flatten() - .collect::>(), - ); - - // mask inner - builder.connect( - &sha256_ipad.output(0).expect("sha256 missing output 0")[..], - &masked_inner.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &mask_inner[..], - &masked_inner.input(1).expect("nbit_xor missing input 1")[..], - ); - - // mask outer - builder.connect( - &sha256_opad.output(0).expect("sha256 missing output 0")[..], - &masked_outer.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &mask_outer[..], - &masked_outer.input(1).expect("nbit_xor missing input 1")[..], - ); - - let mut builder = builder.build_gates(); - - let out_inner = builder.add_output( - "MASKED_I", - "32-byte masked HMAC inner hash state", - ValueType::Bytes, - 256, - ); - let out_outer = builder.add_output( - "MASKED_O", - "32-byte masked HMAC outer hash state", - ValueType::Bytes, - 256, - ); - - builder.connect( - &masked_inner.output(0).expect("nbit_xor missing output 0")[..], - &out_inner[..], - ); - builder.connect( - &masked_outer.output(0).expect("nbit_xor missing output 0")[..], - &out_outer[..], - ); - - builder.build_circuit().expect("failed to build c2") -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::test_helpers::{finalize_sha256_digest, partial_sha256_digest, test_circ}; - use mpc_circuits::Value; - use rand::{thread_rng, Rng}; - - #[test] - #[ignore = "expensive"] - fn test_c2() { - let circ = c2(); - // Perform in the clear all the computations which happen inside the ciruit: - let mut rng = thread_rng(); - - let n_outer_hash_state: [u32; 8] = rng.gen(); - let u_inner_hash_p1: [u8; 32] = rng.gen(); - let u_p2: [u8; 16] = rng.gen(); - - // * generate user's and notary's inside-the-GC-masks to mask the GC output - let mask_n: [u8; 32] = rng.gen(); - let mask_u: [u8; 32] = rng.gen(); - - // finalize the hash to get p1 - let p1 = finalize_sha256_digest(n_outer_hash_state, 64, &u_inner_hash_p1); - // get master_secret - let mut ms = [0u8; 48]; - ms[..32].copy_from_slice(&p1); - ms[32..48].copy_from_slice(&u_p2[..16]); - - // * XOR ms (zero-padded to 64 bytes) with inner/outer padding of HMAC - let mut ms_zeropadded = [0u8; 64]; - ms_zeropadded[0..48].copy_from_slice(&ms); - - let pms_ipad = ms_zeropadded.iter().map(|b| b ^ 0x36).collect::>(); - let pms_opad = ms_zeropadded.iter().map(|b| b ^ 0x5c).collect::>(); - - // * hash the padded PMS - let ohash_state = partial_sha256_digest(&pms_opad); - let ihash_state = partial_sha256_digest(&pms_ipad); - // convert into u8 array - let ohash_state_u8: Vec = ohash_state - .iter() - .map(|u32t| u32t.to_be_bytes()) - .flatten() - .collect(); - let ihash_state_u8: Vec = ihash_state - .iter() - .map(|u32t| u32t.to_be_bytes()) - .flatten() - .collect(); - - // * masked hash state are the expected circuit's outputs - let expected_inner = ihash_state_u8 - .into_iter() - .zip(mask_u) - .map(|(b, mask)| b ^ mask) - .collect::>(); - let expected_outer = ohash_state_u8 - .into_iter() - .zip(mask_n) - .map(|(b, mask)| b ^ mask) - .collect::>(); - - test_circ( - &circ, - &[ - Value::Bytes( - n_outer_hash_state - .into_iter() - .rev() - .map(|v| v.to_le_bytes()) - .flatten() - .collect::>(), - ), - Value::Bytes(u_inner_hash_p1.into_iter().rev().collect()), - Value::Bytes(u_p2.into_iter().rev().collect()), - Value::Bytes(mask_u.iter().rev().copied().collect::>()), - Value::Bytes(mask_n.iter().rev().copied().collect::>()), - ], - &[ - Value::Bytes(expected_inner.into_iter().rev().collect()), - Value::Bytes(expected_outer.into_iter().rev().collect()), - ], - ); - } -} diff --git a/tls/tls-circuits/src/c3.rs b/tls/tls-circuits/src/c3.rs deleted file mode 100644 index 8462dd61a..000000000 --- a/tls/tls-circuits/src/c3.rs +++ /dev/null @@ -1,406 +0,0 @@ -use std::sync::Arc; - -use mpc_circuits::{ - builder::{map_le_bytes, CircuitBuilder}, - circuits::nbit_xor, - Circuit, ValueType, SHA_256, -}; - -/// TLS stage 3 -/// -/// Compute expanded p1 which consists of client_write_key + server_write_key -/// Compute expanded p2 which consists of client_IV + server_IV -/// -/// Inputs: -/// -/// 0. OUTER_HASH_STATE: 32-byte outer-hash state -/// 1. N_CWK_MASK: 16-byte mask for client write-key -/// 2. N_SWK_MASK: 16-byte mask for server write-key -/// 3. N_CIV_MASK: 4-byte mask for client IV -/// 4. N_SIV_MASK: 4-byte mask for server IV -/// 5. P1_INNER: 32-byte inner hash for p1_expanded_keys -/// 6. P2_INNER: 32-byte inner hash for p2_expanded_keys -/// 7. U_CWK_MASK: 16-byte mask for client write-key -/// 8. U_SWK_MASK: 16-byte mask for server write-key -/// 9. U_CIV_MASK: 4-byte mask for client IV -/// 10. U_SIV_MASK: 4-byte mask for server IV -/// -/// Outputs: -/// -/// 0. MASKED_CWK: 16-byte masked (N_CWK_MASK + U_CWK_MASK) client write-key -/// 1. MASKED_SWK: 16-byte masked (N_SWK_MASK + U_SWK_MASK) server write-key -/// 2. MASKED_CIV: 4-byte masked (N_CIV_MASK + U_CIV_MASK) client IV -/// 3. MASKED_SIV: 4-byte masked (N_SIV_MASK + U_SIV_MASK) server IV -pub fn c3() -> Arc { - let mut builder = CircuitBuilder::new("c3", "", "0.1.0"); - - let outer_state = builder.add_input( - "OUTER_HASH_STATE", - "32-byte hash state", - ValueType::Bytes, - 256, - ); - let n_cwk_mask = builder.add_input( - "N_CWK_MASK", - "16-byte mask for client write-key", - ValueType::Bytes, - 128, - ); - let n_swk_mask = builder.add_input( - "N_SWK_MASK", - "16-byte mask for server write-key", - ValueType::Bytes, - 128, - ); - let n_civ_mask = builder.add_input( - "N_CIV_MASK", - "4-byte mask for client IV", - ValueType::Bytes, - 32, - ); - let n_siv_mask = builder.add_input( - "N_SIV_MASK", - "4-byte mask for server IV", - ValueType::Bytes, - 32, - ); - let p1_hash = builder.add_input( - "P1_INNER", - "32-byte inner hash for p1_expanded_keys", - ValueType::Bytes, - 256, - ); - let p2_hash = builder.add_input( - "P2_INNER", - "32-byte inner hash for p2_expanded_keys", - ValueType::Bytes, - 256, - ); - let u_cwk_mask = builder.add_input( - "U_CWK_MASK", - "16-byte mask for client write-key", - ValueType::Bytes, - 128, - ); - let u_swk_mask = builder.add_input( - "U_SWK_MASK", - "16-byte mask for server write-key", - ValueType::Bytes, - 128, - ); - let u_civ_mask = builder.add_input( - "U_CIV_MASK", - "4-byte mask for client IV", - ValueType::Bytes, - 32, - ); - let u_siv_mask = builder.add_input( - "U_SIV_MASK", - "4-byte mask for server IV", - ValueType::Bytes, - 32, - ); - let const_zero = builder.add_input( - "const_zero", - "input that is always 0", - ValueType::ConstZero, - 1, - ); - let const_one = builder.add_input( - "const_one", - "input that is always 1", - ValueType::ConstOne, - 1, - ); - - let mut builder = builder.build_inputs(); - - let sha256 = Circuit::load_bytes(SHA_256).expect("failed to load sha256 circuit"); - let xor_128_circ = nbit_xor(128); - let xor_32_circ = nbit_xor(32); - - let sha256_p1 = builder.add_circ(&sha256); - let sha256_p2 = builder.add_circ(&sha256); - let mask_cwk = builder.add_circ(&xor_128_circ); - let mask_swk = builder.add_circ(&xor_128_circ); - let mask_civ = builder.add_circ(&xor_32_circ); - let mask_siv = builder.add_circ(&xor_32_circ); - let masked_cwk = builder.add_circ(&xor_128_circ); - let masked_swk = builder.add_circ(&xor_128_circ); - let masked_civ = builder.add_circ(&xor_32_circ); - let masked_siv = builder.add_circ(&xor_32_circ); - - // p1 - let sha256_p1_msg = sha256_p1.input(0).expect("sha256 missing input 0"); - builder.connect(&p1_hash[..], &sha256_p1_msg[256..]); - // append a single '1' bit - builder.connect(&[const_one[0]], &[sha256_p1_msg[255]]); - // append K '0' bits, where K is the minimum number >= 0 such that (L + 1 + K + 64) is a multiple of 512 - builder.connect(&[const_zero[0]; 239], &sha256_p1_msg[16..255]); - // append L as a 64-bit big-endian integer, making the total post-processed length a multiple of 512 bits - // L = 768 = 0x0300 - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &sha256_p1_msg[..16], - &[0x00, 0x03], - ); - builder.connect( - &outer_state[..], - &sha256_p1.input(1).expect("sha256 missing input 1")[..], - ); - - let p1 = sha256_p1.output(0).expect("sha256 missing output 0"); - - // p2 - let sha256_p2_msg = sha256_p2.input(0).expect("sha256 missing input 0"); - builder.connect(&p2_hash[..], &sha256_p2_msg[256..]); - // append a single '1' bit - builder.connect(&[const_one[0]], &[sha256_p2_msg[255]]); - // append K '0' bits, where K is the minimum number >= 0 such that (L + 1 + K + 64) is a multiple of 512 - builder.connect(&[const_zero[0]; 239], &sha256_p2_msg[16..255]); - // append L as a 64-bit big-endian integer, making the total post-processed length a multiple of 512 bits - // L = 768 = 0x0300 - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &sha256_p2_msg[..16], - &[0x00, 0x03], - ); - builder.connect( - &outer_state[..], - &sha256_p2.input(1).expect("sha256 missing input 1")[..], - ); - - let p2 = sha256_p2.output(0).expect("sha256 missing output 0"); - - // cwk mask - builder.connect( - &n_cwk_mask[..], - &mask_cwk.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_cwk_mask[..], - &mask_cwk.input(1).expect("nbit_xor missing input 1")[..], - ); - - // swk mask - builder.connect( - &n_swk_mask[..], - &mask_swk.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_swk_mask[..], - &mask_swk.input(1).expect("nbit_xor missing input 1")[..], - ); - - // civ mask - builder.connect( - &n_civ_mask[..], - &mask_civ.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_civ_mask[..], - &mask_civ.input(1).expect("nbit_xor missing input 1")[..], - ); - - // siv mask - builder.connect( - &n_siv_mask[..], - &mask_siv.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_siv_mask[..], - &mask_siv.input(1).expect("nbit_xor missing input 1")[..], - ); - - // apply cwk mask - builder.connect( - &mask_cwk.output(0).expect("nbit_xor missing output 0")[..], - &masked_cwk.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &p1[128..], - &masked_cwk.input(1).expect("nbit_xor missing input 1")[..], - ); - - // apply swk mask - builder.connect( - &mask_swk.output(0).expect("nbit_xor missing output 0")[..], - &masked_swk.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &p1[..128], - &masked_swk.input(1).expect("nbit_xor missing input 1")[..], - ); - - // apply civ mask - builder.connect( - &mask_civ.output(0).expect("nbit_xor missing output 0")[..], - &masked_civ.input(0).expect("nbit_xor missing input 1")[..], - ); - builder.connect( - &p2[224..], - &masked_civ.input(1).expect("nbit_xor missing input 1")[..], - ); - - // apply siv mask - builder.connect( - &mask_siv.output(0).expect("nbit_xor missing output 0")[..], - &masked_siv.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &p2[192..224], - &masked_siv.input(1).expect("nbit_xor missing input 1")[..], - ); - - let mut builder = builder.build_gates(); - - let cwk = builder.add_output( - "MASKED_CWK", - "16-byte masked client write-key", - ValueType::Bytes, - 128, - ); - let swk = builder.add_output( - "MASKED_SWK", - "16-byte masked server write-key", - ValueType::Bytes, - 128, - ); - let civ = builder.add_output( - "MASKED_CIV", - "4-byte masked client IV", - ValueType::Bytes, - 32, - ); - let siv = builder.add_output( - "MASKED_SIV", - "4-byte masked server IV", - ValueType::Bytes, - 32, - ); - - builder.connect( - &masked_cwk.output(0).expect("nbit_xor missing output 0")[..], - &cwk[..], - ); - builder.connect( - &masked_swk.output(0).expect("nbit_xor missing output 0")[..], - &swk[..], - ); - builder.connect( - &masked_civ.output(0).expect("nbit_xor missing output 0")[..], - &civ[..], - ); - builder.connect( - &masked_siv.output(0).expect("nbit_xor missing output 0")[..], - &siv[..], - ); - - builder.build_circuit().expect("failed to build c3") -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::test_helpers::{finalize_sha256_digest, test_circ}; - use mpc_circuits::Value; - use rand::{thread_rng, Rng}; - - #[test] - #[ignore = "expensive"] - fn test_c3() { - let circ = c3(); - // Perform in the clear all the computations which happen inside the ciruit: - let mut rng = thread_rng(); - - let n_outer_hash_state: [u32; 8] = rng.gen(); - let n_cwk_mask: [u8; 16] = rng.gen(); - let n_swk_mask: [u8; 16] = rng.gen(); - let n_civ_mask: [u8; 4] = rng.gen(); - let n_siv_mask: [u8; 4] = rng.gen(); - let u_inner_hash_p1: [u8; 32] = rng.gen(); - let u_inner_hash_p2: [u8; 32] = rng.gen(); - let u_cwk_mask: [u8; 16] = rng.gen(); - let u_swk_mask: [u8; 16] = rng.gen(); - let u_civ_mask: [u8; 4] = rng.gen(); - let u_siv_mask: [u8; 4] = rng.gen(); - - // finalize the hash to get p1 - let p1 = finalize_sha256_digest(n_outer_hash_state, 64, &u_inner_hash_p1); - // finalize the hash to get p2 - let p2 = finalize_sha256_digest(n_outer_hash_state, 64, &u_inner_hash_p2); - - // get expanded_keys (TLS session keys) - let mut ek = [0u8; 40]; - ek[..32].copy_from_slice(&p1); - ek[32..40].copy_from_slice(&p2[0..8]); - // split into client/server_write_key and client/server_write_iv - let mut cwk = [0u8; 16]; - cwk.copy_from_slice(&ek[0..16]); - let mut swk = [0u8; 16]; - swk.copy_from_slice(&ek[16..32]); - let mut civ = [0u8; 4]; - civ.copy_from_slice(&ek[32..36]); - let mut siv = [0u8; 4]; - siv.copy_from_slice(&ek[36..40]); - - let cwk_masked = cwk - .iter() - .zip(n_cwk_mask) - .zip(u_cwk_mask) - .map(|((v, n_mask), u_mask)| v ^ n_mask ^ u_mask) - .collect::>(); - let swk_masked = swk - .iter() - .zip(n_swk_mask) - .zip(u_swk_mask) - .map(|((v, n_mask), u_mask)| v ^ n_mask ^ u_mask) - .collect::>(); - let civ_masked = civ - .iter() - .zip(n_civ_mask) - .zip(u_civ_mask) - .map(|((v, n_mask), u_mask)| v ^ n_mask ^ u_mask) - .collect::>(); - let siv_masked = siv - .iter() - .zip(n_siv_mask) - .zip(u_siv_mask) - .map(|((v, n_mask), u_mask)| v ^ n_mask ^ u_mask) - .collect::>(); - - test_circ( - &circ, - &[ - Value::Bytes( - n_outer_hash_state - .into_iter() - .rev() - .map(|v| v.to_le_bytes()) - .flatten() - .collect::>(), - ), - Value::Bytes(n_cwk_mask.iter().rev().copied().collect::>()), - Value::Bytes(n_swk_mask.iter().rev().copied().collect::>()), - Value::Bytes(n_civ_mask.iter().rev().copied().collect::>()), - Value::Bytes(n_siv_mask.iter().rev().copied().collect::>()), - Value::Bytes(u_inner_hash_p1.into_iter().rev().collect()), - Value::Bytes(u_inner_hash_p2.into_iter().rev().collect()), - Value::Bytes(u_cwk_mask.iter().rev().copied().collect::>()), - Value::Bytes(u_swk_mask.iter().rev().copied().collect::>()), - Value::Bytes(u_civ_mask.iter().rev().copied().collect::>()), - Value::Bytes(u_siv_mask.iter().rev().copied().collect::>()), - ], - &[ - Value::Bytes(cwk_masked.into_iter().rev().collect()), - Value::Bytes(swk_masked.into_iter().rev().collect()), - Value::Bytes(civ_masked.into_iter().rev().collect()), - Value::Bytes(siv_masked.into_iter().rev().collect()), - ], - ); - } -} diff --git a/tls/tls-circuits/src/c4.rs b/tls/tls-circuits/src/c4.rs deleted file mode 100644 index c6212eded..000000000 --- a/tls/tls-circuits/src/c4.rs +++ /dev/null @@ -1,382 +0,0 @@ -use std::sync::Arc; - -use mpc_circuits::{ - builder::{map_le_bytes, CircuitBuilder}, - circuits::nbit_xor, - Circuit, ValueType, AES_128_REVERSE, -}; - -/// TLS stage 4 -/// -/// Compute ghash H, gctr block, encrypted counter block - needed for Client Finished -/// -/// Inputs: -/// -/// 0. N_CWK: 16-byte Notary share of client write-key -/// 1. N_CIV: 4-byte Notary share of client IV -/// 2. N_H_MASK: 16-byte Notary mask for H -/// 3. N_GCTR_MASK: 16-byte Notary mask for GCTR -/// 4. U_CWK: 16-byte User share of client write-key -/// 5. U_CIV: 4-byte User share of client IV -/// 6. U_H_MASK: 16-byte User mask for H -/// 7. U_GCTR_MASK: 16-byte User mask for GCTR -/// 8. U_ECTR_MASK: 16-byte User mask for encrypted counter -/// -/// Outputs: -/// -/// 0. MASKED_H: 16-byte masked (N_H_MASK + U_H_MASK) H -/// 1. MASKED_GCTR: 16-byte masked (N_GCTR_MASK + U_GCTR_MASK) GCTR -/// 2. MASKED_ECTR: 16-byte masked (U_ECTR_MASK) encrypted counter -pub fn c4() -> Arc { - let mut builder = CircuitBuilder::new("c4", "", "0.1.0"); - - let n_cwk = builder.add_input( - "N_CWK", - "16-byte Notary share of client write-key", - ValueType::Bytes, - 128, - ); - let n_civ = builder.add_input( - "N_CIV", - "4-byte Notary share of client IV", - ValueType::Bytes, - 32, - ); - let n_h_mask = builder.add_input( - "N_H_MASK", - "16-byte Notary mask for H", - ValueType::Bytes, - 128, - ); - let n_gctr_mask = builder.add_input( - "N_GCTR_MASK", - "16-byte Notary mask for GCTR", - ValueType::Bytes, - 128, - ); - let u_cwk = builder.add_input( - "U_CWK", - "16-byte User share of client write-key", - ValueType::Bytes, - 128, - ); - let u_civ = builder.add_input( - "U_CIV", - "4-byte User share of client IV", - ValueType::Bytes, - 32, - ); - let u_h_mask = builder.add_input("U_H_MASK", "16-byte User mask for H", ValueType::Bytes, 128); - let u_gctr_mask = builder.add_input( - "U_GCTR_MASK", - "16-byte User mask for GCTR", - ValueType::Bytes, - 128, - ); - let u_ectr_mask = builder.add_input( - "U_ECTR_MASK", - "16-byte User mask for ECTR", - ValueType::Bytes, - 128, - ); - let const_zero = builder.add_input( - "const_zero", - "input that is always 0", - ValueType::ConstZero, - 1, - ); - let const_one = builder.add_input( - "const_one", - "input that is always 1", - ValueType::ConstOne, - 1, - ); - - let mut builder = builder.build_inputs(); - - let aes = Circuit::load_bytes(AES_128_REVERSE).expect("failed to load aes_128_reverse circuit"); - let xor_128_circ = nbit_xor(128); - let xor_32_circ = nbit_xor(32); - - let aes_h = builder.add_circ(&aes); - let aes_gctr = builder.add_circ(&aes); - let aes_ectr = builder.add_circ(&aes); - let cwk = builder.add_circ(&xor_128_circ); - let civ = builder.add_circ(&xor_32_circ); - let mask_h = builder.add_circ(&xor_128_circ); - let mask_gctr = builder.add_circ(&xor_128_circ); - let masked_h = builder.add_circ(&xor_128_circ); - let masked_gctr = builder.add_circ(&xor_128_circ); - let masked_ectr = builder.add_circ(&xor_128_circ); - - // cwk - builder.connect( - &n_cwk[..], - &cwk.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_cwk[..], - &cwk.input(1).expect("nbit_xor missing input 1")[..], - ); - let cwk = cwk.output(0).expect("nbit_xor missing output 0"); - - // civ - builder.connect( - &n_civ[..], - &civ.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_civ[..], - &civ.input(1).expect("nbit_xor missing input 1")[..], - ); - let civ = civ.output(0).expect("nbit_xor missing output 0"); - - // Compute H - builder.connect(&cwk[..], &aes_h.input(0).expect("aes missing input 0")[..]); - // encrypt all zeroes - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &aes_h.input(1).expect("aes missing input 1")[..], - &[0u8; 16], - ); - let h = aes_h.output(0).expect("aes missing output 0"); - - // Compute GCTR - builder.connect( - &cwk[..], - &aes_gctr.input(0).expect("aes missing input 0")[..], - ); - let aes_gctr_m = aes_gctr.input(1).expect("aes missing input 1"); - builder.connect(&civ[..], &aes_gctr_m[96..]); - // Nonce (0x1) + CTR (0x1) - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &aes_gctr_m[..96], - &[ - 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - ], - ); - let gctr = aes_gctr.output(0).expect("aes missing output 0"); - - // Compute ECTR - builder.connect( - &cwk[..], - &aes_ectr.input(0).expect("aes missing input 0")[..], - ); - let aes_ectr_m = aes_ectr.input(1).expect("aes missing input 1"); - builder.connect(&civ[..], &aes_ectr_m[96..]); - // Nonce (0x1) + CTR (0x2) - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &aes_ectr_m[..96], - &[ - 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - ], - ); - let ectr = aes_ectr.output(0).expect("aes missing output 0"); - - // H mask - builder.connect( - &n_h_mask[..], - &mask_h.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_h_mask[..], - &mask_h.input(1).expect("nbit_xor missing input 1")[..], - ); - let mask_h = mask_h.output(0).expect("nbit_xor missing output 0"); - - // GCTR mask - builder.connect( - &n_gctr_mask[..], - &mask_gctr.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_gctr_mask[..], - &mask_gctr.input(1).expect("nbit_xor missing input 1")[..], - ); - let mask_gctr = mask_gctr.output(0).expect("nbit_xor missing output 0"); - - // Apply H mask - builder.connect( - &mask_h[..], - &masked_h.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &h[..], - &masked_h.input(1).expect("nbit_xor missing input 0")[..], - ); - - // Apply GCTR mask - builder.connect( - &mask_gctr[..], - &masked_gctr.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &gctr[..], - &masked_gctr.input(1).expect("nbit_xor missing input 0")[..], - ); - - // Apply ECTR mask - builder.connect( - &ectr[..], - &masked_ectr.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_ectr_mask[..], - &masked_ectr.input(1).expect("nbit_xor missing input 0")[..], - ); - - let mut builder = builder.build_gates(); - - let out_h = builder.add_output( - "MASKED_H", - "16-byte masked (N_H_MASK + U_H_MASK) H", - ValueType::Bytes, - 128, - ); - let out_gctr = builder.add_output( - "MASKED_GCTR", - "16-byte masked (N_GCTR_MASK + U_GCTR_MASK) GCTR", - ValueType::Bytes, - 128, - ); - let out_ectr = builder.add_output( - "MASKED_ECTR", - "16-byte masked (U_ECTR_MASK) encrypted counter", - ValueType::Bytes, - 128, - ); - - builder.connect( - &masked_h.output(0).expect("nbit_xor missing output 0")[..], - &out_h[..], - ); - builder.connect( - &masked_gctr.output(0).expect("nbit_xor missing output 0")[..], - &out_gctr[..], - ); - builder.connect( - &masked_ectr.output(0).expect("nbit_xor missing output 0")[..], - &out_ectr[..], - ); - - builder.build_circuit().expect("failed to build c4") -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::test_helpers::test_circ; - use aes::{Aes128, BlockEncrypt, NewBlockCipher}; - use generic_array::GenericArray; - use mpc_circuits::Value; - use rand::{thread_rng, Rng}; - - #[test] - #[ignore = "expensive"] - fn test_c4() { - let circ = c4(); - - let mut rng = thread_rng(); - - let n_cwk: [u8; 16] = rng.gen(); - let n_civ: [u8; 4] = rng.gen(); - let n_h_mask: [u8; 16] = rng.gen(); - let n_gctr_mask: [u8; 16] = rng.gen(); - let u_cwk: [u8; 16] = rng.gen(); - let u_civ: [u8; 4] = rng.gen(); - let u_h_mask: [u8; 16] = rng.gen(); - let u_gctr_mask: [u8; 16] = rng.gen(); - let u_ectr_mask: [u8; 16] = rng.gen(); - - // combine key shares - let cwk = n_cwk - .iter() - .zip(u_cwk) - .map(|(n, u)| n ^ u) - .collect::>(); - let civ = n_civ - .iter() - .zip(u_civ) - .map(|(n, u)| n ^ u) - .collect::>(); - - // set AES key - let cipher = Aes128::new_from_slice(&cwk).unwrap(); - - // AES-ECB encrypt 0, get MAC key - let mut z = GenericArray::clone_from_slice(&[0u8; 16]); - cipher.encrypt_block(&mut z); - let mac_key = z; - - // AES-ECB encrypt a block with counter==1 and nonce==1, get GCTR block - let nonce: [u8; 8] = 1u64.to_be_bytes(); - let counter: [u8; 4] = 1u32.to_be_bytes(); - let mut msg = [0u8; 16]; - msg[0..4].copy_from_slice(&civ); - msg[4..12].copy_from_slice(&nonce); - msg[12..16].copy_from_slice(&counter); - let mut msg = GenericArray::clone_from_slice(&msg); - cipher.encrypt_block(&mut msg); - let gctr_block = msg; - - // AES-ECB encrypt a block with counter==2 and nonce==1 - let nonce: [u8; 8] = 1u64.to_be_bytes(); - let counter: [u8; 4] = 2u32.to_be_bytes(); - let mut msg = [0u8; 16]; - msg[0..4].copy_from_slice(&civ); - msg[4..12].copy_from_slice(&nonce); - msg[12..16].copy_from_slice(&counter); - let mut msg = GenericArray::clone_from_slice(&msg); - cipher.encrypt_block(&mut msg); - let first_block = msg; - - // XOR MAC key and GCTR block with Notary's mask and then with User's mask - let mac_key_masked = mac_key - .iter() - .zip(n_h_mask) - .zip(u_h_mask) - .map(|((v, n), u)| v ^ n ^ u) - .collect::>(); - let gctr_block_masked = gctr_block - .iter() - .zip(n_gctr_mask) - .zip(u_gctr_mask) - .map(|((v, n), u)| v ^ n ^ u) - .collect::>(); - - // XOR the first block with User's mask - let first_block_masked = first_block - .iter() - .zip(u_ectr_mask) - .map(|(v, u)| v ^ u) - .collect::>(); - - test_circ( - &circ, - &[ - Value::Bytes(n_cwk.into_iter().rev().collect()), - Value::Bytes(n_civ.into_iter().rev().collect()), - Value::Bytes(n_h_mask.into_iter().rev().collect()), - Value::Bytes(n_gctr_mask.into_iter().rev().collect()), - Value::Bytes(u_cwk.into_iter().rev().collect()), - Value::Bytes(u_civ.into_iter().rev().collect()), - Value::Bytes(u_h_mask.into_iter().rev().collect()), - Value::Bytes(u_gctr_mask.into_iter().rev().collect()), - Value::Bytes(u_ectr_mask.into_iter().rev().collect()), - ], - &[ - Value::Bytes(mac_key_masked.into_iter().rev().collect()), - Value::Bytes(gctr_block_masked.into_iter().rev().collect()), - Value::Bytes(first_block_masked.into_iter().rev().collect()), - ], - ); - } -} diff --git a/tls/tls-circuits/src/c5.rs b/tls/tls-circuits/src/c5.rs deleted file mode 100644 index 6083b9cd6..000000000 --- a/tls/tls-circuits/src/c5.rs +++ /dev/null @@ -1,480 +0,0 @@ -use std::sync::Arc; - -use mpc_circuits::{ - builder::{map_le_bytes, CircuitBuilder}, - circuits::nbit_xor, - Circuit, ValueType, AES_128_REVERSE, SHA_256, -}; - -/// TLS stage 5 -/// -/// Compute ghash H, gctr block, encrypted counter block, verify_data - needed for Server Finished -/// -/// Inputs: -/// -/// 0. P1_OUTER_STATE: 32-byte outer hash state for P1 -/// 1. N_SWK: 16-byte Notary share of server write-key -/// 2. N_SIV: 4-byte Notary share of server IV -/// 3. N_H_MASK: 16-byte Notary mask for H -/// 4. N_GCTR_MASK: 16-byte Notary mask for GCTR -/// 5. P1_INNER_STATE: 32-byte inner hash for P1 -/// 6. U_SWK: 16-byte User share of server write-key -/// 7. U_SIV: 4-byte User share of server IV -/// 8. NONCE: 8-byte server_finished nonce -/// 9. U_H_MASK: 16-byte User mask for H -/// 10. U_GCTR_MASK: 16-byte User mask for GCTR -/// 11. U_ECTR_MASK: 16-byte User mask for encrypted counter -/// 12. U_VD_MASK: 12-byte User mask for server verify data -/// -/// Outputs: -/// -/// 0. MASKED_H: 16-byte masked (N_H_MASK + U_H_MASK) H -/// 1. MASKED_GCTR: 16-byte masked (N_GCTR_MASK + U_GCTR_MASK) GCTR -/// 2. MASKED_ECTR: 16-byte masked (U_ECTR_MASK) encrypted counter -/// 3. MASKED_VD: 12-byte masked (U_VD_MASK) server verify data -pub fn c5() -> Arc { - let mut builder = CircuitBuilder::new("c5", "", "0.1.0"); - - let p1_outer_state = builder.add_input( - "P1_OUTER_STATE", - "32-byte outer hash state for P1", - ValueType::Bytes, - 256, - ); - let n_swk = builder.add_input( - "N_SWK", - "16-byte Notary server write-key share", - ValueType::Bytes, - 128, - ); - let n_siv = builder.add_input( - "N_SIV", - "4-byte Notary share of server IV", - ValueType::Bytes, - 32, - ); - let n_h_mask = builder.add_input( - "N_H_MASK", - "16-byte Notary mask for H", - ValueType::Bytes, - 128, - ); - let n_gctr_mask = builder.add_input( - "N_GCTR_MASK", - "16-byte Notary mask for GCTR", - ValueType::Bytes, - 128, - ); - let p1_inner_hash = builder.add_input( - "P1_INNER_STATE", - "32-byte inner hash for P1", - ValueType::Bytes, - 256, - ); - let u_swk = builder.add_input( - "U_SWK", - "16-byte User share of server write-key", - ValueType::Bytes, - 128, - ); - let u_siv = builder.add_input( - "U_SIV", - "4-byte User share of server IV", - ValueType::Bytes, - 32, - ); - let nonce = builder.add_input( - "NONCE", - "8-byte server_finished nonce", - ValueType::Bytes, - 64, - ); - let u_h_mask = builder.add_input("U_H_MASK", "16-byte User mask for H", ValueType::Bytes, 128); - let u_gctr_mask = builder.add_input( - "U_GCTR_MASK", - "16-byte User mask for GCTR", - ValueType::Bytes, - 128, - ); - let u_ectr_mask = builder.add_input( - "U_ECTR_MASK", - "16-byte User mask for ECTR", - ValueType::Bytes, - 128, - ); - let u_vd_mask = builder.add_input( - "U_VD_MASK", - "12-byte User mask for server verify data", - ValueType::Bytes, - 96, - ); - let const_zero = builder.add_input( - "const_zero", - "input that is always 0", - ValueType::ConstZero, - 1, - ); - let const_one = builder.add_input( - "const_one", - "input that is always 1", - ValueType::ConstOne, - 1, - ); - - let mut builder = builder.build_inputs(); - - let aes = Circuit::load_bytes(AES_128_REVERSE).expect("failed to load aes_128_reverse circuit"); - let sha256 = Circuit::load_bytes(SHA_256).expect("failed to load sha_256 circuit"); - let xor_128_circ = nbit_xor(128); - let xor_96_circ = nbit_xor(96); - let xor_32_circ = nbit_xor(32); - - let sha256_p1 = builder.add_circ(&sha256); - let aes_h = builder.add_circ(&aes); - let aes_gctr = builder.add_circ(&aes); - let aes_ectr = builder.add_circ(&aes); - let swk = builder.add_circ(&xor_128_circ); - let siv = builder.add_circ(&xor_32_circ); - let mask_h = builder.add_circ(&xor_128_circ); - let mask_gctr = builder.add_circ(&xor_128_circ); - let masked_h = builder.add_circ(&xor_128_circ); - let masked_gctr = builder.add_circ(&xor_128_circ); - let masked_ectr = builder.add_circ(&xor_128_circ); - let masked_vd = builder.add_circ(&xor_96_circ); - - // swk - builder.connect( - &n_swk[..], - &swk.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_swk[..], - &swk.input(1).expect("nbit_xor missing input 1")[..], - ); - let swk = swk.output(0).expect("nbit_xor missing output 0"); - - // siv - builder.connect( - &n_siv[..], - &siv.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_siv[..], - &siv.input(1).expect("nbit_xor missing input 1")[..], - ); - let siv = siv.output(0).expect("nbit_xor missing output 0"); - - // Compute p1 - let sha256_p1_msg = sha256_p1.input(0).expect("sha256 missing input 0"); - builder.connect(&p1_inner_hash[..], &sha256_p1_msg[256..]); - // append a single '1' bit - builder.connect(&[const_one[0]], &[sha256_p1_msg[255]]); - // append K '0' bits, where K is the minimum number >= 0 such that (L + 1 + K + 64) is a multiple of 512 - builder.connect(&[const_zero[0]; 239], &sha256_p1_msg[16..255]); - // append L as a 64-bit big-endian integer, making the total post-processed length a multiple of 512 bits - // L = 768 = 0x0300 - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &sha256_p1_msg[..16], - &[0x00, 0x03], - ); - builder.connect( - &p1_outer_state[..], - &sha256_p1.input(1).expect("sha256 missing input 1")[..], - ); - let p1 = sha256_p1.output(0).expect("sha256 missing output 0"); - - // Compute H - builder.connect(&swk[..], &aes_h.input(0).expect("aes missing input 0")[..]); - // encrypt all zeroes - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &aes_h.input(1).expect("aes missing input 1")[..], - &[0u8; 16], - ); - let h = aes_h.output(0).expect("aes missing output 0"); - - // Compute GCTR - builder.connect( - &swk[..], - &aes_gctr.input(0).expect("aes missing input 0")[..], - ); - let aes_gctr_m = aes_gctr.input(1).expect("aes missing input 1"); - builder.connect(&siv[..], &aes_gctr_m[96..]); - builder.connect(&nonce[..], &aes_gctr_m[32..96]); - // CTR (0x1) - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &aes_gctr_m[..32], - &[0x01, 0x00, 0x00, 0x00], - ); - let gctr = aes_gctr.output(0).expect("aes missing output 0"); - - // Compute ECTR - builder.connect( - &swk[..], - &aes_ectr.input(0).expect("aes missing input 0")[..], - ); - let aes_ectr_m = aes_ectr.input(1).expect("aes missing input 1"); - builder.connect(&siv[..], &aes_ectr_m[96..]); - builder.connect(&nonce[..], &aes_ectr_m[32..96]); - // CTR (0x2) - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &aes_ectr_m[..32], - &[0x02, 0x00, 0x00, 0x00], - ); - let ectr = aes_ectr.output(0).expect("aes missing output 0"); - - // H mask - builder.connect( - &n_h_mask[..], - &mask_h.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_h_mask[..], - &mask_h.input(1).expect("nbit_xor missing input 1")[..], - ); - let mask_h = mask_h.output(0).expect("nbit_xor missing output 0"); - - // GCTR mask - builder.connect( - &n_gctr_mask[..], - &mask_gctr.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_gctr_mask[..], - &mask_gctr.input(1).expect("nbit_xor missing input 1")[..], - ); - let mask_gctr = mask_gctr.output(0).expect("nbit_xor missing output 0"); - - // Apply H mask - builder.connect( - &mask_h[..], - &masked_h.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &h[..], - &masked_h.input(1).expect("nbit_xor missing input 1")[..], - ); - - // Apply GCTR mask - builder.connect( - &mask_gctr[..], - &masked_gctr.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &gctr[..], - &masked_gctr.input(1).expect("nbit_xor missing input 1")[..], - ); - - // Apply ECTR mask - builder.connect( - &ectr[..], - &masked_ectr.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_ectr_mask[..], - &masked_ectr.input(1).expect("nbit_xor missing input 1")[..], - ); - - // Apply VD mask - builder.connect( - &u_vd_mask[..], - &masked_vd.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &p1[160..], - &masked_vd.input(1).expect("nbit_xor missing input 1")[..], - ); - - let mut builder = builder.build_gates(); - - let out_h = builder.add_output( - "MASKED_H", - "16-byte masked (N_H_MASK + U_H_MASK) H", - ValueType::Bytes, - 128, - ); - let out_gctr = builder.add_output( - "MASKED_GCTR", - "16-byte masked (N_GCTR_MASK + U_GCTR_MASK) GCTR", - ValueType::Bytes, - 128, - ); - let out_ectr = builder.add_output( - "MASKED_ECTR", - "16-byte masked (U_ECTR_MASK) encrypted counter", - ValueType::Bytes, - 128, - ); - let out_vd = builder.add_output( - "MASKED_VD", - "12-byte masked (U_VD_MASK) server verify data", - ValueType::Bytes, - 96, - ); - - builder.connect( - &masked_h.output(0).expect("nbit_xor missing output 0")[..], - &out_h[..], - ); - builder.connect( - &masked_gctr.output(0).expect("nbit_xor missing output 0")[..], - &out_gctr[..], - ); - builder.connect( - &masked_ectr.output(0).expect("nbit_xor missing output 0")[..], - &out_ectr[..], - ); - builder.connect( - &masked_vd.output(0).expect("nbit_xor missing output 0")[..], - &out_vd[..], - ); - - builder.build_circuit().expect("failed to build c5") -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::test_helpers::{finalize_sha256_digest, test_circ}; - use aes::{Aes128, BlockEncrypt, NewBlockCipher}; - use generic_array::GenericArray; - use mpc_circuits::Value; - use rand::{thread_rng, Rng}; - - #[test] - #[ignore = "expensive"] - fn test_c5() { - let circ = c5(); - - let mut rng = thread_rng(); - - let n_outer_hash_state_p1: [u32; 8] = rng.gen(); - let n_swk: [u8; 16] = rng.gen(); - let n_siv: [u8; 4] = rng.gen(); - let n_h_mask: [u8; 16] = rng.gen(); - let n_gctr_mask: [u8; 16] = rng.gen(); - let u_inner_hash_state_p1: [u8; 32] = rng.gen(); - let u_swk: [u8; 16] = rng.gen(); - let u_siv: [u8; 4] = rng.gen(); - let nonce: [u8; 8] = rng.gen(); - let u_h_mask: [u8; 16] = rng.gen(); - let u_gctr_mask: [u8; 16] = rng.gen(); - let u_ectr_mask: [u8; 16] = rng.gen(); - let u_vd_mask: [u8; 12] = rng.gen(); - - // finalize the hash to get p1 - let p1 = finalize_sha256_digest(n_outer_hash_state_p1, 64, &u_inner_hash_state_p1); - let mut verify_data = [0u8; 12]; - verify_data.copy_from_slice(&p1[0..12]); - - // combine key shares - let swk = n_swk - .iter() - .zip(u_swk) - .map(|(n, u)| n ^ u) - .collect::>(); - let siv = n_siv - .iter() - .zip(u_siv) - .map(|(n, u)| n ^ u) - .collect::>(); - - // set AES key - let key = GenericArray::clone_from_slice(&swk); - let cipher = Aes128::new(&key); - - // AES-ECB encrypt 0, get MAC key - let mut z = GenericArray::clone_from_slice(&[0u8; 16]); - cipher.encrypt_block(&mut z); - let mac_key = z; - - // AES-ECB encrypt a block with counter==1 and nonce from Server_Finished, get GCTR block - let counter: [u8; 4] = 1u32.to_be_bytes(); - let mut msg = [0u8; 16]; - msg[0..4].copy_from_slice(&siv); - msg[4..12].copy_from_slice(&nonce); - msg[12..16].copy_from_slice(&counter); - let mut msg = GenericArray::clone_from_slice(&msg); - cipher.encrypt_block(&mut msg); - let gctr_block = msg; - - // AES-ECB encrypt a block with counter==2 and nonce from Server_Finished - let counter: [u8; 4] = 2u32.to_be_bytes(); - let mut msg = [0u8; 16]; - msg[0..4].copy_from_slice(&siv); - msg[4..12].copy_from_slice(&nonce); - msg[12..16].copy_from_slice(&counter); - let mut msg = GenericArray::clone_from_slice(&msg); - cipher.encrypt_block(&mut msg); - let ectr = msg; - - // XOR MAC key and GCTR block with Notary's mask and then with User's mask - let mac_key_masked = mac_key - .iter() - .zip(n_h_mask) - .zip(u_h_mask) - .map(|((v, n), u)| v ^ n ^ u) - .collect::>(); - let gctr_block_masked = gctr_block - .iter() - .zip(n_gctr_mask) - .zip(u_gctr_mask) - .map(|((v, n), u)| v ^ n ^ u) - .collect::>(); - - // XOR the first block and verify_data with User's mask - let ectr_masked = ectr - .iter() - .zip(u_ectr_mask) - .map(|(v, u)| v ^ u) - .collect::>(); - let verify_data_masked = verify_data - .iter() - .zip(u_vd_mask) - .map(|(v, u)| v ^ u) - .collect::>(); - - test_circ( - &circ, - &[ - Value::Bytes( - n_outer_hash_state_p1 - .into_iter() - .rev() - .map(|v| v.to_le_bytes()) - .flatten() - .collect::>(), - ), - Value::Bytes(n_swk.into_iter().rev().collect()), - Value::Bytes(n_siv.into_iter().rev().collect()), - Value::Bytes(n_h_mask.into_iter().rev().collect()), - Value::Bytes(n_gctr_mask.into_iter().rev().collect()), - Value::Bytes(u_inner_hash_state_p1.into_iter().rev().collect()), - Value::Bytes(u_swk.into_iter().rev().collect()), - Value::Bytes(u_siv.into_iter().rev().collect()), - Value::Bytes(nonce.into_iter().rev().collect()), - Value::Bytes(u_h_mask.into_iter().rev().collect()), - Value::Bytes(u_gctr_mask.into_iter().rev().collect()), - Value::Bytes(u_ectr_mask.into_iter().rev().collect()), - Value::Bytes(u_vd_mask.into_iter().rev().collect()), - ], - &[ - Value::Bytes(mac_key_masked.into_iter().rev().collect()), - Value::Bytes(gctr_block_masked.into_iter().rev().collect()), - Value::Bytes(ectr_masked.into_iter().rev().collect()), - Value::Bytes(verify_data_masked.into_iter().rev().collect()), - ], - ); - } -} diff --git a/tls/tls-circuits/src/c6.rs b/tls/tls-circuits/src/c6.rs deleted file mode 100644 index 14ce3a314..000000000 --- a/tls/tls-circuits/src/c6.rs +++ /dev/null @@ -1,182 +0,0 @@ -use std::sync::Arc; - -use mpc_circuits::{ - builder::CircuitBuilder, circuits::nbit_xor, Circuit, ValueType, AES_128_REVERSE, -}; - -/// TLS stage 6 -/// -/// Encrypt plaintext or decrypt ciphertext in AES-CTR mode -/// -/// T_IN could also just be used as a mask for the encrypted counter-block. -/// -/// Inputs: -/// -/// 0. N_K: 16-byte Notary share of write-key -/// 1. N_IV: 4-byte Notary share of IV -/// 2. U_K: 16-byte User share of write-key -/// 3. U_IV: 4-byte User share of IV -/// 4. T_IN: 16-byte text (plaintext or ciphertext) -/// 5. NONCE: 8-byte Explicit Nonce -/// 6. CTR: U32 Counter -/// -/// Outputs: -/// -/// 0. T_OUT: 16-byte output (plaintext or ciphertext) -pub fn c6() -> Arc { - let mut builder = CircuitBuilder::new("c6", "", "0.1.0"); - - let n_k = builder.add_input( - "N_K", - "16-byte Notary write-key share", - ValueType::Bytes, - 128, - ); - let n_iv = builder.add_input("N_SIV", "4-byte Notary share of IV", ValueType::Bytes, 32); - let c_k = builder.add_input( - "U_SWK", - "16-byte User share of write-key", - ValueType::Bytes, - 128, - ); - let c_iv = builder.add_input("U_SIV", "4-byte User share of IV", ValueType::Bytes, 32); - let t_in = builder.add_input( - "T_IN", - "16-byte text (plaintext or ciphertext)", - ValueType::Bytes, - 128, - ); - let nonce = builder.add_input("NONCE", "8-byte Explicit Nonce", ValueType::Bytes, 64); - let ctr = builder.add_input("CTR", "U32 Counter", ValueType::U32, 32); - - let mut builder = builder.build_inputs(); - - let aes = Circuit::load_bytes(AES_128_REVERSE).expect("failed to load aes_128_reverse circuit"); - let xor_128_circ = nbit_xor(128); - let xor_32_circ = nbit_xor(32); - - let aes_ectr = builder.add_circ(&aes); - let k = builder.add_circ(&xor_128_circ); - let iv = builder.add_circ(&xor_32_circ); - let t_out = builder.add_circ(&xor_128_circ); - - // Compute write-key - builder.connect(&n_k[..], &k.input(0).expect("nbit_xor missing input 0")[..]); - builder.connect(&c_k[..], &k.input(1).expect("nbit_xor missing input 1")[..]); - let k = k.output(0).expect("nbit_xor missing output 0"); - - // iv - builder.connect( - &n_iv[..], - &iv.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &c_iv[..], - &iv.input(1).expect("nbit_xor missing input 1")[..], - ); - let iv = iv.output(0).expect("nbit_xor missing output 0"); - - // Compute encrypted counter-block - builder.connect(&k[..], &aes_ectr.input(0).expect("aes missing input 0")[..]); - let aes_ectr_m = aes_ectr.input(1).expect("aes missing input 1"); - // Implicit nonce - builder.connect(&iv[..], &aes_ectr_m[96..]); - // Explicit nonce - builder.connect(&nonce[..], &aes_ectr_m[32..96]); - // Counter - builder.connect(&ctr[..], &aes_ectr_m[..32]); - let ectr = aes_ectr.output(0).expect("aes missing output 0"); - - // Apply text - builder.connect( - &ectr[..], - &t_out.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &t_in[..], - &t_out.input(1).expect("nbit_xor missing input 1")[..], - ); - - let mut builder = builder.build_gates(); - - let out_ectr = builder.add_output( - "T_OUT", - "16-byte output (plaintext or ciphertext)", - ValueType::Bytes, - 128, - ); - - builder.connect( - &t_out.output(0).expect("nbit_xor missing output 0")[..], - &out_ectr[..], - ); - - builder.build_circuit().expect("failed to build c6") -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::test_helpers::test_circ; - use aes::{Aes128, BlockEncrypt, NewBlockCipher}; - use generic_array::GenericArray; - use mpc_circuits::Value; - use rand::{thread_rng, Rng}; - - #[test] - #[ignore = "expensive"] - fn test_c6() { - let circ = c6(); - - let mut rng = thread_rng(); - - let n_k: [u8; 16] = rng.gen(); - let n_iv: [u8; 4] = rng.gen(); - let u_k: [u8; 16] = rng.gen(); - let u_iv: [u8; 4] = rng.gen(); - let t_in: [u8; 16] = rng.gen(); - let explicit_nonce: [u8; 8] = rng.gen(); - let ctr: u32 = rng.gen(); - - // combine key shares - let k = n_k.iter().zip(u_k).map(|(n, u)| n ^ u).collect::>(); - let iv = n_iv - .iter() - .zip(u_iv) - .map(|(n, u)| n ^ u) - .collect::>(); - - // set AES key - let key = GenericArray::clone_from_slice(&k); - let cipher = Aes128::new(&key); - - // AES-ECB encrypt a block with counter and nonce - let mut msg = [0u8; 16]; - msg[0..4].copy_from_slice(&iv); - msg[4..12].copy_from_slice(&explicit_nonce); - msg[12..16].copy_from_slice(&ctr.to_be_bytes()); - let mut msg = GenericArray::clone_from_slice(&msg); - cipher.encrypt_block(&mut msg); - let ectr = msg; - - let t_out = ectr - .iter() - .zip(t_in) - .map(|(v, u)| v ^ u) - .collect::>(); - - test_circ( - &circ, - &[ - Value::Bytes(n_k.into_iter().rev().collect()), - Value::Bytes(n_iv.into_iter().rev().collect()), - Value::Bytes(u_k.into_iter().rev().collect()), - Value::Bytes(u_iv.into_iter().rev().collect()), - Value::Bytes(t_in.into_iter().rev().collect()), - Value::Bytes(explicit_nonce.into_iter().rev().collect()), - Value::U32(ctr), - ], - &[Value::Bytes(t_out.into_iter().rev().collect())], - ); - } -} diff --git a/tls/tls-circuits/src/c7.rs b/tls/tls-circuits/src/c7.rs deleted file mode 100644 index 370d298bc..000000000 --- a/tls/tls-circuits/src/c7.rs +++ /dev/null @@ -1,239 +0,0 @@ -use std::sync::Arc; - -use mpc_circuits::{ - builder::{map_le_bytes, CircuitBuilder}, - circuits::nbit_xor, - Circuit, ValueType, AES_128_REVERSE, -}; - -/// TLS stage 7 -/// -/// Compute GCTR block -/// -/// Inputs: -/// -/// 0. N_CWK: 16-byte Notary share of client write-key -/// 1. N_CIV: 4-byte Notary share of client IV -/// 2. N_MASK: 16-byte User mask for GCTR -/// 3. U_CWK: 16-byte User share of client write-key -/// 4. U_CIV: 4-byte User share of client IV -/// 5. U_MASK: 16-byte User mask for GCTR -/// 6. NONCE: U16 Nonce -/// -/// Outputs: -/// -/// 0. MASKED_GCTR: 16-byte masked (N_MASK + U_MASK) GCTR -pub fn c7() -> Arc { - let mut builder = CircuitBuilder::new("c7", "", "0.1.0"); - - let n_cwk = builder.add_input( - "N_CWK", - "16-byte Notary client write-key share", - ValueType::Bytes, - 128, - ); - let n_civ = builder.add_input( - "N_SIV", - "4-byte Notary share of client IV", - ValueType::Bytes, - 32, - ); - let n_mask = builder.add_input( - "N_MASK", - "16-byte Notary mask for GCTR", - ValueType::Bytes, - 128, - ); - let u_cwk = builder.add_input( - "U_SWK", - "16-byte User share of client write-key", - ValueType::Bytes, - 128, - ); - let u_civ = builder.add_input( - "U_SIV", - "4-byte User share of client IV", - ValueType::Bytes, - 32, - ); - let u_mask = builder.add_input( - "U_MASK", - "16-byte User mask for GCTR", - ValueType::Bytes, - 128, - ); - let nonce = builder.add_input("NONCE", "U64 Nonce", ValueType::U64, 64); - let const_zero = builder.add_input( - "const_zero", - "input that is always 0", - ValueType::ConstZero, - 1, - ); - let const_one = builder.add_input( - "const_one", - "input that is always 1", - ValueType::ConstOne, - 1, - ); - - let mut builder = builder.build_inputs(); - - let aes = Circuit::load_bytes(AES_128_REVERSE).expect("failed to load aes_128_reverse circuit"); - let xor_128_circ = nbit_xor(128); - let xor_32_circ = nbit_xor(32); - - let aes_gctr = builder.add_circ(&aes); - let cwk = builder.add_circ(&xor_128_circ); - let civ = builder.add_circ(&xor_32_circ); - let mask_gctr = builder.add_circ(&xor_128_circ); - let masked_gctr = builder.add_circ(&xor_128_circ); - - // cwk - builder.connect( - &n_cwk[..], - &cwk.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_cwk[..], - &cwk.input(1).expect("nbit_xor missing input 1")[..], - ); - let cwk = cwk.output(0).expect("nbit_xor missing output 0"); - - // civ - builder.connect( - &n_civ[..], - &civ.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_civ[..], - &civ.input(1).expect("nbit_xor missing input 1")[..], - ); - let civ = civ.output(0).expect("nbit_xor missing output 0"); - - // Compute GCTR - builder.connect( - &cwk[..], - &aes_gctr.input(0).expect("aes missing input 0")[..], - ); - let aes_gctr_m = aes_gctr.input(1).expect("aes missing input 1"); - builder.connect(&civ[..], &aes_gctr_m[96..]); - builder.connect(&nonce[..], &aes_gctr_m[32..96]); - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &aes_gctr_m[..32], - &[0x01, 0x00, 0x00, 0x00], - ); - let gctr = aes_gctr.output(0).expect("aes missing output 0"); - - // GCTR mask - builder.connect( - &n_mask[..], - &mask_gctr.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &u_mask[..], - &mask_gctr.input(1).expect("nbit_xor missing input 1")[..], - ); - let mask_gctr = mask_gctr.output(0).expect("nbit_xor missing output 0"); - - // Apply GCTR mask - builder.connect( - &gctr[..], - &masked_gctr.input(0).expect("nbit_xor missing input 0")[..], - ); - builder.connect( - &mask_gctr[..], - &masked_gctr.input(1).expect("nbit_xor missing input 1")[..], - ); - - let mut builder = builder.build_gates(); - - let out_gctr = builder.add_output( - "MASKED_GCTR", - "16-byte masked (N_MASK + U_MASK) GCTR", - ValueType::Bytes, - 128, - ); - - builder.connect( - &masked_gctr.output(0).expect("nbit_xor missing output 0")[..], - &out_gctr[..], - ); - - builder.build_circuit().expect("failed to build c7") -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::test_helpers::test_circ; - use aes::{Aes128, BlockEncrypt, NewBlockCipher}; - use generic_array::GenericArray; - use mpc_circuits::Value; - use rand::{thread_rng, Rng}; - - #[test] - #[ignore = "expensive"] - fn test_c7() { - let circ = c7(); - - let mut rng = thread_rng(); - - let n_cwk: [u8; 16] = rng.gen(); - let n_civ: [u8; 4] = rng.gen(); - let n_mask: [u8; 16] = rng.gen(); - let u_cwk: [u8; 16] = rng.gen(); - let u_civ: [u8; 4] = rng.gen(); - let u_mask: [u8; 16] = rng.gen(); - let nonce: u64 = rng.gen(); - - // combine key shares - let cwk = n_cwk - .iter() - .zip(u_cwk) - .map(|(n, u)| n ^ u) - .collect::>(); - let civ = n_civ - .iter() - .zip(u_civ) - .map(|(n, u)| n ^ u) - .collect::>(); - - // set AES key - let key = GenericArray::clone_from_slice(&cwk); - let cipher = Aes128::new(&key); - - // AES-ECB encrypt a block with counter and nonce - let mut msg = [0u8; 16]; - msg[0..4].copy_from_slice(&civ); - msg[4..12].copy_from_slice(&nonce.to_be_bytes()); - msg[12..16].copy_from_slice(&1u32.to_be_bytes()); - let mut msg = GenericArray::clone_from_slice(&msg); - cipher.encrypt_block(&mut msg); - let gctr = msg; - - // XOR the first block and verify_data with User's mask - let gctr_masked = gctr - .iter() - .zip(n_mask) - .zip(u_mask) - .map(|((v, n), u)| v ^ n ^ u) - .collect::>(); - - test_circ( - &circ, - &[ - Value::Bytes(n_cwk.into_iter().rev().collect()), - Value::Bytes(n_civ.into_iter().rev().collect()), - Value::Bytes(n_mask.into_iter().rev().collect()), - Value::Bytes(u_cwk.into_iter().rev().collect()), - Value::Bytes(u_civ.into_iter().rev().collect()), - Value::Bytes(u_mask.into_iter().rev().collect()), - Value::U64(nonce), - ], - &[Value::Bytes(gctr_masked.into_iter().rev().collect())], - ); - } -} diff --git a/tls/tls-circuits/src/combine_pms_shares.rs b/tls/tls-circuits/src/combine_pms_shares.rs deleted file mode 100644 index 7780bd588..000000000 --- a/tls/tls-circuits/src/combine_pms_shares.rs +++ /dev/null @@ -1,146 +0,0 @@ -use std::sync::Arc; - -use mpc_circuits::{ - builder::{map_le_bytes, CircuitBuilder}, - circuits::nbit_add_mod, - Circuit, ValueType, -}; - -/// Combines two PMS shares -/// -/// Each share must already be reduced mod P -pub fn combine_pms_shares() -> Arc { - let mut builder = CircuitBuilder::new("pms_shares", "", "0.1.0"); - - let a = builder.add_input( - "PMS_SHARE_A", - "256-bit PMS Additive Share", - ValueType::Bytes, - 256, - ); - let b = builder.add_input( - "PMS_SHARE_B", - "256-bit PMS Additive Share", - ValueType::Bytes, - 256, - ); - let const_zero = builder.add_input( - "const_zero", - "input that is always 0", - ValueType::ConstZero, - 1, - ); - let const_one = builder.add_input( - "const_one", - "input that is always 1", - ValueType::ConstOne, - 1, - ); - - let mut builder = builder.build_inputs(); - - let add_mod = builder.add_circ(&nbit_add_mod(256)); - let add_mod_a = add_mod.input(0).expect("add mod is missing input 0"); - let add_mod_b = add_mod.input(1).expect("add mod is missing input 1"); - let add_mod_mod = add_mod.input(2).expect("add mod is missing input 2"); - let add_mod_const_zero = add_mod.input(3).expect("add mod is missing input 3"); - let add_mod_const_one = add_mod.input(4).expect("add mod is missing input 4"); - let add_mod_out = add_mod.output(0).expect("add mod is missing output 0"); - - builder.connect(&[const_zero[0]], &[add_mod_const_zero[0]]); - builder.connect(&[const_one[0]], &[add_mod_const_one[0]]); - - builder.connect(&a[..], &add_mod_a[..]); - builder.connect(&b[..], &add_mod_b[..]); - - // map p256 prime to mod - map_le_bytes( - &mut builder, - const_zero[0], - const_one[0], - &add_mod_mod[..], - &[ - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0xFF, 0xFF, 0xFF, 0xFF, - ], - ); - - let mut builder = builder.build_gates(); - let out = builder.add_output("PMS", "Pre-master Secret", ValueType::Bytes, 256); - - builder.connect(&add_mod_out[..], &out[..]); - - builder - .build_circuit() - .expect("failed to build combine_pms_shares") -} - -#[cfg(test)] -mod tests { - use super::*; - use mpc_circuits::{circuits::test_circ, Value}; - use num_bigint::BigUint; - use num_traits::One; - - /// NIST P-256 Prime - const P: &str = "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff"; - - #[test] - #[ignore = "expensive"] - fn test_combine_pms_shares() { - let circ = combine_pms_shares(); - let p = BigUint::parse_bytes(P.as_bytes(), 16).unwrap(); - let mut one = vec![0x00; 32]; - one[0] = 1; - let mut two = vec![0x00; 32]; - two[0] = 2; - // 0 + 0 mod p = 0 - test_circ( - &circ, - &[Value::Bytes(vec![0x00; 32]), Value::Bytes(vec![0x00; 32])], - &[Value::Bytes(vec![0x00; 32])], - ); - // 0 + 1 mod p = 1 - test_circ( - &circ, - &[Value::Bytes(vec![0x00; 32]), Value::Bytes(one.clone())], - &[Value::Bytes(one.clone())], - ); - let a = [vec![255; 16], vec![0; 16]].concat(); - let b = [vec![255; 16], vec![0; 16]].concat(); - let expected = [vec![254], vec![255; 15], vec![1], vec![0; 15]].concat(); - test_circ( - &circ, - &[Value::Bytes(a), Value::Bytes(b)], - &[Value::Bytes(expected)], - ); - let p_minus_one = p.clone() - BigUint::one(); - // (p + p - 2) mod p = p - 2 - test_circ( - &circ, - &[ - Value::Bytes(p_minus_one.to_bytes_le()), - Value::Bytes(p_minus_one.to_bytes_le()), - ], - &[Value::Bytes( - ((p_minus_one.clone() + p_minus_one) % p.clone()).to_bytes_le(), - )], - ); - // (p - 1) + 2 mod p = 1 - test_circ( - &circ, - &[ - Value::Bytes((p.clone() - BigUint::one()).to_bytes_le()), - Value::Bytes(two.clone()), - ], - &[Value::Bytes(one.clone())], - ); - // p + 0 mod p = 0 - test_circ( - &circ, - &[Value::Bytes(p.to_bytes_le()), Value::Bytes(vec![0; 32])], - &[Value::Bytes(vec![0; 32])], - ); - } -} diff --git a/tls/tls-circuits/src/lib.rs b/tls/tls-circuits/src/lib.rs deleted file mode 100644 index 5200c3543..000000000 --- a/tls/tls-circuits/src/lib.rs +++ /dev/null @@ -1,88 +0,0 @@ -mod c1; -mod c2; -mod c3; -mod c4; -mod c5; -mod c6; -mod c7; -mod combine_pms_shares; - -pub use c1::c1; -pub use c2::c2; -pub use c3::c3; -pub use c4::c4; -pub use c5::c5; -pub use c6::c6; -pub use c7::c7; -pub use combine_pms_shares::combine_pms_shares; - -static SHA256_STATE: [u32; 8] = [ - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, -]; - -#[cfg(test)] -mod test_helpers { - use std::slice::from_ref; - - use generic_array::typenum::U64; - use sha2::{ - compress256, - digest::block_buffer::{BlockBuffer, Eager}, - }; - - use mpc_circuits::{Circuit, Value, WireGroup}; - - pub fn test_circ(circ: &Circuit, inputs: &[Value], expected: &[Value]) { - let inputs: Vec = inputs - .iter() - .zip(circ.inputs()) - .map(|(value, input)| input.clone().to_value(value.clone()).unwrap()) - .collect(); - let outputs = circ.evaluate(&inputs).unwrap(); - for (output, expected) in outputs.iter().zip(expected) { - if output.value() != expected { - let report = format!( - "Circuit {}\n{}{}Expected: {:?}", - circ.description(), - inputs - .iter() - .enumerate() - .map(|(id, input)| format!("Input {}: {:?}\n", id, input.value())) - .collect::>() - .join(""), - format!("Output {}: {:?}\n", output.index(), output.value()), - expected - ); - panic!("{}", report.to_string()); - } - } - } - - pub fn partial_sha256_digest(input: &[u8]) -> [u32; 8] { - let mut state = [ - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, - 0x5be0cd19, - ]; - for b in input.chunks_exact(64) { - let block = generic_array::GenericArray::from_slice(b); - sha2::compress256(&mut state, &[*block]); - } - state - } - - pub fn finalize_sha256_digest(mut state: [u32; 8], pos: usize, input: &[u8]) -> [u8; 32] { - let mut buffer = BlockBuffer::::default(); - buffer.digest_blocks(input, |b| compress256(&mut state, b)); - buffer.digest_pad( - 0x80, - &(((input.len() + pos) * 8) as u64).to_be_bytes(), - |b| compress256(&mut state, from_ref(b)), - ); - - let mut out: [u8; 32] = [0; 32]; - for (chunk, v) in out.chunks_exact_mut(4).zip(state.iter()) { - chunk.copy_from_slice(&v.to_be_bytes()); - } - out - } -}