feat(prf): reduced MPC variant (#735)

* feat(prf): reduced MPC variant

* move sending `client_random` from `alloc` to `preprocess`

* rename `Config` -> `Mode` and rename variants

* add feedback for handling of prf config

* fix formatting to nightly

* simplify `MpcPrf`

* improve external flush handling

* improve control flow

* improved inner control flow for normal prf version

* rename leftover `config` -> `mode`

* remove unnecessary pub(crate)

* rewrite state flow for reduced prf

* improve state transition for reduced prf

* repair prf bench

* WIP: Adapting to new `Sha256` from mpz

* repair failing test

* fixed all tests

* remove output decoding for p

* do not use mod.rs file hierarchy

* remove pub(crate) from function

* improve config handling

* use `Array::try_from`

* simplify hmac to function

* remove `merge_vecs`

* move `mark_public` to allocation

* minor fixes

* simplify state logic for reduced prf even more

* simplify reduced prf even more

* set reduced prf as default

* temporarily fix commit for mpz

* add part of feedback

* simplify state transition

* adapt comment

* improve state transition in flush

* simplify flush

* fix wasm prover config

---------

Co-authored-by: sinu <65924192+sinui0@users.noreply.github.com>
This commit is contained in:
th4s
2025-05-13 18:26:43 +02:00
committed by GitHub
parent 2c500b13bd
commit 6ccf102ec8
35 changed files with 1972 additions and 1329 deletions

View File

@@ -8,8 +8,7 @@ members = [
"crates/common",
"crates/components/deap",
"crates/components/cipher",
#"crates/components/hmac-sha256",
#"crates/components/hmac-sha256-circuits",
"crates/components/hmac-sha256",
"crates/components/key-exchange",
"crates/core",
"crates/data-fixtures",
@@ -57,8 +56,7 @@ tlsn-core = { path = "crates/core" }
tlsn-data-fixtures = { path = "crates/data-fixtures" }
tlsn-deap = { path = "crates/components/deap" }
tlsn-formats = { path = "crates/formats" }
#tlsn-hmac-sha256 = { path = "crates/components/hmac-sha256" }
#tlsn-hmac-sha256-circuits = { path = "crates/components/hmac-sha256-circuits" }
tlsn-hmac-sha256 = { path = "crates/components/hmac-sha256" }
tlsn-key-exchange = { path = "crates/components/key-exchange" }
tlsn-mpc-tls = { path = "crates/mpc-tls" }
tlsn-prover = { path = "crates/prover" }
@@ -71,18 +69,19 @@ tlsn-tls-core = { path = "crates/tls/core" }
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
tlsn-verifier = { path = "crates/verifier" }
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" }
mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" }
mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" }
mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" }
mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" }
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" }
mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" }
mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" }
mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" }
mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" }
mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" }
mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" }
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-hash = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
rangeset = { version = "0.2" }
serio = { version = "0.2" }

View File

@@ -18,10 +18,10 @@ mpz-core = { workspace = true }
mpz-garble = { workspace = true }
mpz-ot = { workspace = true, features = ["ideal"] }
tlsn-benches-library = { workspace = true }
tlsn-benches-browser-native = { workspace = true, optional = true}
tlsn-benches-browser-native = { workspace = true, optional = true }
tlsn-common = { workspace = true }
tlsn-core = { workspace = true }
#tlsn-hmac-sha256 = { workspace = true }
tlsn-hmac-sha256 = { workspace = true }
tlsn-prover = { workspace = true }
tlsn-server-fixture = { workspace = true }
tlsn-server-fixture-certs = { workspace = true }
@@ -30,7 +30,7 @@ tlsn-verifier = { workspace = true }
anyhow = { workspace = true }
async-trait = { workspace = true }
charming = {version = "0.3.1", features = ["ssr"]}
charming = { version = "0.3.1", features = ["ssr"] }
csv = "1.3.0"
dhat = { version = "0.3.3" }
env_logger = { version = "0.6.0", default-features = false }
@@ -46,7 +46,8 @@ tokio = { workspace = true, features = [
] }
tokio-util = { workspace = true }
toml = "0.8.11"
tracing-subscriber = {workspace = true, features = ["env-filter"]}
tracing-subscriber = { workspace = true, features = ["env-filter"] }
rand = { workspace = true }
[[bin]]
name = "bench"

View File

@@ -1,6 +1,5 @@
pub mod config;
pub mod metrics;
mod preprocess;
pub mod prover;
pub mod prover_main;
pub mod verifier_main;

View File

@@ -1,5 +0,0 @@
use hmac_sha256::build_circuits;
pub async fn preprocess_prf_circuits() {
build_circuits().await;
}

View File

@@ -14,7 +14,6 @@ use std::{
use crate::{
config::{BenchInstance, Config},
metrics::Metrics,
preprocess::preprocess_prf_circuits,
set_interface, PROVER_INTERFACE,
};
use anyhow::Context;
@@ -58,10 +57,6 @@ pub async fn prover_main(is_memory_profiling: bool) -> anyhow::Result<()> {
.open("metrics.csv")
.context("failed to open metrics file")?;
// Preprocess the PRF circuits as they are allocating a lot of memory, which
// don't need to be accounted for in the benchmarks.
preprocess_prf_circuits().await;
{
let mut metric_wrt = WriterBuilder::new()
// If file is not empty, assume that the CSV header is already present in the file.

View File

@@ -4,7 +4,6 @@
use crate::{
config::{BenchInstance, Config},
preprocess::preprocess_prf_circuits,
set_interface, VERIFIER_INTERFACE,
};
use tls_core::verify::WebPkiVerifier;
@@ -40,10 +39,6 @@ pub async fn verifier_main(is_memory_profiling: bool) -> anyhow::Result<()> {
.await
.context("failed to bind to port")?;
// Preprocess the PRF circuits as they are allocating a lot of memory, which
// don't need to be accounted for in the benchmarks.
preprocess_prf_circuits().await;
for bench in config.benches {
for instance in bench.flatten() {
if is_memory_profiling && !instance.memory_profile {

View File

@@ -41,6 +41,9 @@ pub struct ProtocolConfig {
/// of the MPC-TLS connection.
#[builder(default = "true")]
defer_decryption_from_start: bool,
/// Network settings.
#[builder(default)]
network: NetworkSetting,
/// Version that is being run by prover/verifier.
#[builder(setter(skip), default = "VERSION.clone()")]
version: Version,
@@ -95,6 +98,11 @@ impl ProtocolConfig {
pub fn defer_decryption_from_start(&self) -> bool {
self.defer_decryption_from_start
}
/// Returns the network settings.
pub fn network(&self) -> NetworkSetting {
self.network
}
}
/// Protocol configuration validator used by checker (i.e. verifier) to perform
@@ -216,6 +224,24 @@ impl ProtocolConfigValidator {
}
}
/// Settings for the network environment.
///
/// Provides optimization options to adapt the protocol to different network
/// situations.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum NetworkSetting {
/// Prefers a bandwidth-heavy protocol.
Bandwidth,
/// Prefers a latency-heavy protocol.
Latency,
}
impl Default for NetworkSetting {
fn default() -> Self {
Self::Bandwidth
}
}
/// A ProtocolConfig error.
#[derive(thiserror::Error, Debug)]
pub struct ProtocolConfigError {

View File

@@ -1,22 +0,0 @@
[package]
name = "tlsn-hmac-sha256-circuits"
authors = ["TLSNotary Team"]
description = "The 2PC circuits for TLS HMAC-SHA256 PRF"
keywords = ["tls", "mpc", "2pc", "hmac", "sha256"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.11-pre"
edition = "2021"
[lints]
workspace = true
[lib]
name = "hmac_sha256_circuits"
[dependencies]
mpz-circuits = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
ring = { workspace = true }

View File

@@ -1,159 +0,0 @@
use std::cell::RefCell;
use mpz_circuits::{
circuits::{sha256, sha256_compress, sha256_compress_trace, sha256_trace},
types::{U32, U8},
BuilderState, Tracer,
};
static SHA256_INITIAL_STATE: [u32; 8] = [
0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
];
/// Returns the outer and inner states of HMAC-SHA256 with the provided key.
///
/// Outer state is H(key ⊕ opad)
///
/// Inner state is H(key ⊕ ipad)
///
/// # Arguments
///
/// * `builder_state` - Reference to builder state.
/// * `key` - N-byte key (must be <= 64 bytes).
pub fn hmac_sha256_partial_trace<'a>(
builder_state: &'a RefCell<BuilderState>,
key: &[Tracer<'a, U8>],
) -> ([Tracer<'a, U32>; 8], [Tracer<'a, U32>; 8]) {
assert!(key.len() <= 64);
let mut opad = [Tracer::new(
builder_state,
builder_state.borrow_mut().get_constant(0x5cu8),
); 64];
let mut ipad = [Tracer::new(
builder_state,
builder_state.borrow_mut().get_constant(0x36u8),
); 64];
key.iter().enumerate().for_each(|(i, k)| {
opad[i] = opad[i] ^ *k;
ipad[i] = ipad[i] ^ *k;
});
let sha256_initial_state: [_; 8] = SHA256_INITIAL_STATE
.map(|v| Tracer::new(builder_state, builder_state.borrow_mut().get_constant(v)));
let outer_state = sha256_compress_trace(builder_state, sha256_initial_state, opad);
let inner_state = sha256_compress_trace(builder_state, sha256_initial_state, ipad);
(outer_state, inner_state)
}
/// Reference implementation of HMAC-SHA256 partial function.
///
/// Returns the outer and inner states of HMAC-SHA256 with the provided key.
///
/// Outer state is H(key ⊕ opad)
///
/// Inner state is H(key ⊕ ipad)
///
/// # Arguments
///
/// * `key` - N-byte key (must be <= 64 bytes).
pub fn hmac_sha256_partial(key: &[u8]) -> ([u32; 8], [u32; 8]) {
assert!(key.len() <= 64);
let mut opad = [0x5cu8; 64];
let mut ipad = [0x36u8; 64];
key.iter().enumerate().for_each(|(i, k)| {
opad[i] ^= k;
ipad[i] ^= k;
});
let outer_state = sha256_compress(SHA256_INITIAL_STATE, opad);
let inner_state = sha256_compress(SHA256_INITIAL_STATE, ipad);
(outer_state, inner_state)
}
/// HMAC-SHA256 finalization function.
///
/// Returns the HMAC-SHA256 digest of the provided message using existing outer
/// and inner states.
///
/// # Arguments
///
/// * `outer_state` - 256-bit outer state.
/// * `inner_state` - 256-bit inner state.
/// * `msg` - N-byte message.
pub fn hmac_sha256_finalize_trace<'a>(
builder_state: &'a RefCell<BuilderState>,
outer_state: [Tracer<'a, U32>; 8],
inner_state: [Tracer<'a, U32>; 8],
msg: &[Tracer<'a, U8>],
) -> [Tracer<'a, U8>; 32] {
sha256_trace(
builder_state,
outer_state,
64,
&sha256_trace(builder_state, inner_state, 64, msg),
)
}
/// Reference implementation of the HMAC-SHA256 finalization function.
///
/// Returns the HMAC-SHA256 digest of the provided message using existing outer
/// and inner states.
///
/// # Arguments
///
/// * `outer_state` - 256-bit outer state.
/// * `inner_state` - 256-bit inner state.
/// * `msg` - N-byte message.
pub fn hmac_sha256_finalize(outer_state: [u32; 8], inner_state: [u32; 8], msg: &[u8]) -> [u8; 32] {
sha256(outer_state, 64, &sha256(inner_state, 64, msg))
}
#[cfg(test)]
mod tests {
use mpz_circuits::{test_circ, CircuitBuilder};
use super::*;
#[test]
fn test_hmac_sha256_partial() {
let builder = CircuitBuilder::new();
let key = builder.add_array_input::<u8, 48>();
let (outer_state, inner_state) = hmac_sha256_partial_trace(builder.state(), &key);
builder.add_output(outer_state);
builder.add_output(inner_state);
let circ = builder.build().unwrap();
let key = [69u8; 48];
test_circ!(circ, hmac_sha256_partial, fn(&key) -> ([u32; 8], [u32; 8]));
}
#[test]
fn test_hmac_sha256_finalize() {
let builder = CircuitBuilder::new();
let outer_state = builder.add_array_input::<u32, 8>();
let inner_state = builder.add_array_input::<u32, 8>();
let msg = builder.add_array_input::<u8, 47>();
let hash = hmac_sha256_finalize_trace(builder.state(), outer_state, inner_state, &msg);
builder.add_output(hash);
let circ = builder.build().unwrap();
let key = [69u8; 32];
let (outer_state, inner_state) = hmac_sha256_partial(&key);
let msg = [42u8; 47];
test_circ!(
circ,
hmac_sha256_finalize,
fn(outer_state, inner_state, &msg) -> [u8; 32]
);
}
}

View File

@@ -1,61 +0,0 @@
//! HMAC-SHA256 circuits.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
#![forbid(unsafe_code)]
mod hmac_sha256;
mod prf;
mod session_keys;
mod verify_data;
pub use hmac_sha256::{
hmac_sha256_finalize, hmac_sha256_finalize_trace, hmac_sha256_partial,
hmac_sha256_partial_trace,
};
pub use prf::{prf, prf_trace};
pub use session_keys::{session_keys, session_keys_trace};
pub use verify_data::{verify_data, verify_data_trace};
use mpz_circuits::{Circuit, CircuitBuilder, Tracer};
use std::sync::Arc;
/// Builds session key derivation circuit.
#[tracing::instrument(level = "trace")]
pub fn build_session_keys() -> Arc<Circuit> {
let builder = CircuitBuilder::new();
let pms = builder.add_array_input::<u8, 32>();
let client_random = builder.add_array_input::<u8, 32>();
let server_random = builder.add_array_input::<u8, 32>();
let (cwk, swk, civ, siv, outer_state, inner_state) =
session_keys_trace(builder.state(), pms, client_random, server_random);
builder.add_output(cwk);
builder.add_output(swk);
builder.add_output(civ);
builder.add_output(siv);
builder.add_output(outer_state);
builder.add_output(inner_state);
Arc::new(builder.build().expect("session keys should build"))
}
/// Builds a verify data circuit.
#[tracing::instrument(level = "trace")]
pub fn build_verify_data(label: &[u8]) -> Arc<Circuit> {
let builder = CircuitBuilder::new();
let outer_state = builder.add_array_input::<u32, 8>();
let inner_state = builder.add_array_input::<u32, 8>();
let handshake_hash = builder.add_array_input::<u8, 32>();
let vd = verify_data_trace(
builder.state(),
outer_state,
inner_state,
&label
.iter()
.map(|v| Tracer::new(builder.state(), builder.get_constant(*v).to_inner()))
.collect::<Vec<_>>(),
handshake_hash,
);
builder.add_output(vd);
Arc::new(builder.build().expect("verify data should build"))
}

View File

@@ -1,227 +0,0 @@
//! This module provides an implementation of the HMAC-SHA256 PRF defined in [RFC 5246](https://www.rfc-editor.org/rfc/rfc5246#section-5).
use std::cell::RefCell;
use mpz_circuits::{
types::{U32, U8},
BuilderState, Tracer,
};
use crate::hmac_sha256::{hmac_sha256_finalize, hmac_sha256_finalize_trace};
fn p_hash_trace<'a>(
builder_state: &'a RefCell<BuilderState>,
outer_state: [Tracer<'a, U32>; 8],
inner_state: [Tracer<'a, U32>; 8],
seed: &[Tracer<'a, U8>],
iterations: usize,
) -> Vec<Tracer<'a, U8>> {
// A() is defined as:
//
// A(0) = seed
// A(i) = HMAC_hash(secret, A(i-1))
let mut a_cache: Vec<_> = Vec::with_capacity(iterations + 1);
a_cache.push(seed.to_vec());
for i in 0..iterations {
let a_i = hmac_sha256_finalize_trace(builder_state, outer_state, inner_state, &a_cache[i]);
a_cache.push(a_i.to_vec());
}
// HMAC_hash(secret, A(i) + seed)
let mut output: Vec<_> = Vec::with_capacity(iterations * 32);
for i in 0..iterations {
let mut a_i_seed = a_cache[i + 1].clone();
a_i_seed.extend_from_slice(seed);
let hash = hmac_sha256_finalize_trace(builder_state, outer_state, inner_state, &a_i_seed);
output.extend_from_slice(&hash);
}
output
}
fn p_hash(outer_state: [u32; 8], inner_state: [u32; 8], seed: &[u8], iterations: usize) -> Vec<u8> {
// A() is defined as:
//
// A(0) = seed
// A(i) = HMAC_hash(secret, A(i-1))
let mut a_cache: Vec<_> = Vec::with_capacity(iterations + 1);
a_cache.push(seed.to_vec());
for i in 0..iterations {
let a_i = hmac_sha256_finalize(outer_state, inner_state, &a_cache[i]);
a_cache.push(a_i.to_vec());
}
// HMAC_hash(secret, A(i) + seed)
let mut output: Vec<_> = Vec::with_capacity(iterations * 32);
for i in 0..iterations {
let mut a_i_seed = a_cache[i + 1].clone();
a_i_seed.extend_from_slice(seed);
let hash = hmac_sha256_finalize(outer_state, inner_state, &a_i_seed);
output.extend_from_slice(&hash);
}
output
}
/// Computes PRF(secret, label, seed).
///
/// # Arguments
///
/// * `builder_state` - Reference to builder state.
/// * `outer_state` - The outer state of HMAC-SHA256.
/// * `inner_state` - The inner state of HMAC-SHA256.
/// * `seed` - The seed to use.
/// * `label` - The label to use.
/// * `bytes` - The number of bytes to output.
pub fn prf_trace<'a>(
builder_state: &'a RefCell<BuilderState>,
outer_state: [Tracer<'a, U32>; 8],
inner_state: [Tracer<'a, U32>; 8],
seed: &[Tracer<'a, U8>],
label: &[Tracer<'a, U8>],
bytes: usize,
) -> Vec<Tracer<'a, U8>> {
let iterations = bytes / 32 + (bytes % 32 != 0) as usize;
let mut label_seed = label.to_vec();
label_seed.extend_from_slice(seed);
let mut output = p_hash_trace(
builder_state,
outer_state,
inner_state,
&label_seed,
iterations,
);
output.truncate(bytes);
output
}
/// Reference implementation of PRF(secret, label, seed).
///
/// # Arguments
///
/// * `outer_state` - The outer state of HMAC-SHA256.
/// * `inner_state` - The inner state of HMAC-SHA256.
/// * `seed` - The seed to use.
/// * `label` - The label to use.
/// * `bytes` - The number of bytes to output.
pub fn prf(
outer_state: [u32; 8],
inner_state: [u32; 8],
seed: &[u8],
label: &[u8],
bytes: usize,
) -> Vec<u8> {
let iterations = bytes / 32 + (bytes % 32 != 0) as usize;
let mut label_seed = label.to_vec();
label_seed.extend_from_slice(seed);
let mut output = p_hash(outer_state, inner_state, &label_seed, iterations);
output.truncate(bytes);
output
}
#[cfg(test)]
mod tests {
use mpz_circuits::{evaluate, CircuitBuilder};
use crate::hmac_sha256::hmac_sha256_partial;
use super::*;
#[test]
fn test_p_hash() {
let builder = CircuitBuilder::new();
let outer_state = builder.add_array_input::<u32, 8>();
let inner_state = builder.add_array_input::<u32, 8>();
let seed = builder.add_array_input::<u8, 64>();
let output = p_hash_trace(builder.state(), outer_state, inner_state, &seed, 2);
builder.add_output(output);
let circ = builder.build().unwrap();
let outer_state = [0u32; 8];
let inner_state = [1u32; 8];
let seed = [42u8; 64];
let expected = p_hash(outer_state, inner_state, &seed, 2);
let actual = evaluate!(circ, fn(outer_state, inner_state, &seed) -> Vec<u8>).unwrap();
assert_eq!(actual, expected);
}
#[test]
fn test_prf() {
let builder = CircuitBuilder::new();
let outer_state = builder.add_array_input::<u32, 8>();
let inner_state = builder.add_array_input::<u32, 8>();
let seed = builder.add_array_input::<u8, 64>();
let label = builder.add_array_input::<u8, 13>();
let output = prf_trace(builder.state(), outer_state, inner_state, &seed, &label, 48);
builder.add_output(output);
let circ = builder.build().unwrap();
let master_secret = [0u8; 48];
let seed = [43u8; 64];
let label = b"master secret";
let (outer_state, inner_state) = hmac_sha256_partial(&master_secret);
let expected = prf(outer_state, inner_state, &seed, label, 48);
let actual =
evaluate!(circ, fn(outer_state, inner_state, &seed, label) -> Vec<u8>).unwrap();
assert_eq!(actual, expected);
let mut expected_ring = [0u8; 48];
ring_prf::prf(&mut expected_ring, &master_secret, label, &seed);
assert_eq!(actual, expected_ring);
}
// Borrowed from Rustls for testing
// https://github.com/rustls/rustls/blob/main/rustls/src/tls12/prf.rs
mod ring_prf {
use ring::{hmac, hmac::HMAC_SHA256};
fn concat_sign(key: &hmac::Key, a: &[u8], b: &[u8]) -> hmac::Tag {
let mut ctx = hmac::Context::with_key(key);
ctx.update(a);
ctx.update(b);
ctx.sign()
}
fn p(out: &mut [u8], secret: &[u8], seed: &[u8]) {
let hmac_key = hmac::Key::new(HMAC_SHA256, secret);
// A(1)
let mut current_a = hmac::sign(&hmac_key, seed);
let chunk_size = HMAC_SHA256.digest_algorithm().output_len();
for chunk in out.chunks_mut(chunk_size) {
// P_hash[i] = HMAC_hash(secret, A(i) + seed)
let p_term = concat_sign(&hmac_key, current_a.as_ref(), seed);
chunk.copy_from_slice(&p_term.as_ref()[..chunk.len()]);
// A(i+1) = HMAC_hash(secret, A(i))
current_a = hmac::sign(&hmac_key, current_a.as_ref());
}
}
fn concat(a: &[u8], b: &[u8]) -> Vec<u8> {
let mut ret = Vec::new();
ret.extend_from_slice(a);
ret.extend_from_slice(b);
ret
}
pub(crate) fn prf(out: &mut [u8], secret: &[u8], label: &[u8], seed: &[u8]) {
let joined_seed = concat(label, seed);
p(out, secret, &joined_seed);
}
}
}

View File

@@ -1,200 +0,0 @@
use std::cell::RefCell;
use mpz_circuits::{
types::{U32, U8},
BuilderState, Tracer,
};
use crate::{
hmac_sha256::{hmac_sha256_partial, hmac_sha256_partial_trace},
prf::{prf, prf_trace},
};
/// Session Keys.
///
/// Computes expanded p1 which consists of client_write_key + server_write_key.
/// Computes expanded p2 which consists of client_IV + server_IV.
///
/// # Arguments
///
/// * `builder_state` - Reference to builder state.
/// * `pms` - 32-byte premaster secret.
/// * `client_random` - 32-byte client random.
/// * `server_random` - 32-byte server random.
///
/// # Returns
///
/// * `client_write_key` - 16-byte client write key.
/// * `server_write_key` - 16-byte server write key.
/// * `client_IV` - 4-byte client IV.
/// * `server_IV` - 4-byte server IV.
/// * `outer_hash_state` - 256-bit master-secret outer HMAC state.
/// * `inner_hash_state` - 256-bit master-secret inner HMAC state.
#[allow(clippy::type_complexity)]
pub fn session_keys_trace<'a>(
builder_state: &'a RefCell<BuilderState>,
pms: [Tracer<'a, U8>; 32],
client_random: [Tracer<'a, U8>; 32],
server_random: [Tracer<'a, U8>; 32],
) -> (
[Tracer<'a, U8>; 16],
[Tracer<'a, U8>; 16],
[Tracer<'a, U8>; 4],
[Tracer<'a, U8>; 4],
[Tracer<'a, U32>; 8],
[Tracer<'a, U32>; 8],
) {
let (pms_outer_state, pms_inner_state) = hmac_sha256_partial_trace(builder_state, &pms);
let master_secret = {
let seed = client_random
.iter()
.chain(&server_random)
.copied()
.collect::<Vec<_>>();
let label = b"master secret"
.map(|v| Tracer::new(builder_state, builder_state.borrow_mut().get_constant(v)));
prf_trace(
builder_state,
pms_outer_state,
pms_inner_state,
&seed,
&label,
48,
)
};
let (master_secret_outer_state, master_secret_inner_state) =
hmac_sha256_partial_trace(builder_state, &master_secret);
let key_material = {
let seed = server_random
.iter()
.chain(&client_random)
.copied()
.collect::<Vec<_>>();
let label = b"key expansion"
.map(|v| Tracer::new(builder_state, builder_state.borrow_mut().get_constant(v)));
prf_trace(
builder_state,
master_secret_outer_state,
master_secret_inner_state,
&seed,
&label,
40,
)
};
let cwk = key_material[0..16].try_into().unwrap();
let swk = key_material[16..32].try_into().unwrap();
let civ = key_material[32..36].try_into().unwrap();
let siv = key_material[36..40].try_into().unwrap();
(
cwk,
swk,
civ,
siv,
master_secret_outer_state,
master_secret_inner_state,
)
}
/// Reference implementation of session keys derivation.
pub fn session_keys(
pms: [u8; 32],
client_random: [u8; 32],
server_random: [u8; 32],
) -> ([u8; 16], [u8; 16], [u8; 4], [u8; 4]) {
let (pms_outer_state, pms_inner_state) = hmac_sha256_partial(&pms);
let master_secret = {
let seed = client_random
.iter()
.chain(&server_random)
.copied()
.collect::<Vec<_>>();
let label = b"master secret";
prf(pms_outer_state, pms_inner_state, &seed, label, 48)
};
let (master_secret_outer_state, master_secret_inner_state) =
hmac_sha256_partial(&master_secret);
let key_material = {
let seed = server_random
.iter()
.chain(&client_random)
.copied()
.collect::<Vec<_>>();
let label = b"key expansion";
prf(
master_secret_outer_state,
master_secret_inner_state,
&seed,
label,
40,
)
};
let cwk = key_material[0..16].try_into().unwrap();
let swk = key_material[16..32].try_into().unwrap();
let civ = key_material[32..36].try_into().unwrap();
let siv = key_material[36..40].try_into().unwrap();
(cwk, swk, civ, siv)
}
#[cfg(test)]
mod tests {
use mpz_circuits::{evaluate, CircuitBuilder};
use super::*;
#[test]
fn test_session_keys() {
let builder = CircuitBuilder::new();
let pms = builder.add_array_input::<u8, 32>();
let client_random = builder.add_array_input::<u8, 32>();
let server_random = builder.add_array_input::<u8, 32>();
let (cwk, swk, civ, siv, outer_state, inner_state) =
session_keys_trace(builder.state(), pms, client_random, server_random);
builder.add_output(cwk);
builder.add_output(swk);
builder.add_output(civ);
builder.add_output(siv);
builder.add_output(outer_state);
builder.add_output(inner_state);
let circ = builder.build().unwrap();
let pms = [0u8; 32];
let client_random = [42u8; 32];
let server_random = [69u8; 32];
let (expected_cwk, expected_swk, expected_civ, expected_siv) =
session_keys(pms, client_random, server_random);
let (cwk, swk, civ, siv, _, _) = evaluate!(
circ,
fn(
pms,
client_random,
server_random,
) -> ([u8; 16], [u8; 16], [u8; 4], [u8; 4], [u32; 8], [u32; 8])
)
.unwrap();
assert_eq!(cwk, expected_cwk);
assert_eq!(swk, expected_swk);
assert_eq!(civ, expected_civ);
assert_eq!(siv, expected_siv);
}
}

View File

@@ -1,88 +0,0 @@
use std::cell::RefCell;
use mpz_circuits::{
types::{U32, U8},
BuilderState, Tracer,
};
use crate::prf::{prf, prf_trace};
/// Computes verify_data as specified in RFC 5246, Section 7.4.9.
///
/// verify_data
/// PRF(master_secret, finished_label,
/// Hash(handshake_messages))[0..verify_data_length-1];
///
/// # Arguments
///
/// * `builder_state` - The builder state.
/// * `outer_state` - The outer HMAC state of the master secret.
/// * `inner_state` - The inner HMAC state of the master secret.
/// * `label` - The label to use.
/// * `hs_hash` - The handshake hash.
pub fn verify_data_trace<'a>(
builder_state: &'a RefCell<BuilderState>,
outer_state: [Tracer<'a, U32>; 8],
inner_state: [Tracer<'a, U32>; 8],
label: &[Tracer<'a, U8>],
hs_hash: [Tracer<'a, U8>; 32],
) -> [Tracer<'a, U8>; 12] {
let vd = prf_trace(builder_state, outer_state, inner_state, &hs_hash, label, 12);
vd.try_into().expect("vd is 12 bytes")
}
/// Reference implementation of verify_data as specified in RFC 5246, Section
/// 7.4.9.
///
/// # Arguments
///
/// * `outer_state` - The outer HMAC state of the master secret.
/// * `inner_state` - The inner HMAC state of the master secret.
/// * `label` - The label to use.
/// * `hs_hash` - The handshake hash.
pub fn verify_data(
outer_state: [u32; 8],
inner_state: [u32; 8],
label: &[u8],
hs_hash: [u8; 32],
) -> [u8; 12] {
let vd = prf(outer_state, inner_state, &hs_hash, label, 12);
vd.try_into().expect("vd is 12 bytes")
}
#[cfg(test)]
mod tests {
use super::*;
use mpz_circuits::{evaluate, CircuitBuilder};
const CF_LABEL: &[u8; 15] = b"client finished";
#[test]
fn test_verify_data() {
let builder = CircuitBuilder::new();
let outer_state = builder.add_array_input::<u32, 8>();
let inner_state = builder.add_array_input::<u32, 8>();
let label = builder.add_array_input::<u8, 15>();
let hs_hash = builder.add_array_input::<u8, 32>();
let vd = verify_data_trace(builder.state(), outer_state, inner_state, &label, hs_hash);
builder.add_output(vd);
let circ = builder.build().unwrap();
let outer_state = [0u32; 8];
let inner_state = [1u32; 8];
let hs_hash = [42u8; 32];
let expected = prf(outer_state, inner_state, &hs_hash, CF_LABEL, 12);
let actual = evaluate!(
circ,
fn(outer_state, inner_state, CF_LABEL, hs_hash) -> [u8; 12]
)
.unwrap();
assert_eq!(actual.to_vec(), expected);
}
}

View File

@@ -14,22 +14,15 @@ workspace = true
[lib]
name = "hmac_sha256"
[features]
default = ["mock"]
rayon = ["mpz-common/rayon"]
mock = []
[dependencies]
tlsn-hmac-sha256-circuits = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-core = { workspace = true }
mpz-circuits = { workspace = true }
mpz-common = { workspace = true, features = ["cpu"] }
mpz-hash = { workspace = true }
derive_builder = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
futures = { workspace = true }
sha2 = { workspace = true }
[dev-dependencies]
mpz-ot = { workspace = true, features = ["ideal"] }
@@ -39,7 +32,8 @@ mpz-common = { workspace = true, features = ["test-utils"] }
criterion = { workspace = true, features = ["async_tokio"] }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }
rand06-compat = { workspace = true }
hex = { workspace = true }
ring = { workspace = true }
[[bench]]
name = "prf"

View File

@@ -2,16 +2,16 @@
use criterion::{criterion_group, criterion_main, Criterion};
use hmac_sha256::{MpcPrf, PrfConfig, Role};
use hmac_sha256::{Mode, MpcPrf};
use mpz_common::context::test_mt_context;
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_ot::ideal::cot::ideal_cot;
use mpz_vm_core::{
memory::{binary::U8, correlated::Delta, Array},
prelude::*,
Execute,
};
use rand::{rngs::StdRng, SeedableRng};
use rand06_compat::Rand0_6CompatExt;
#[allow(clippy::unit_arg)]
fn criterion_benchmark(c: &mut Criterion) {
@@ -36,10 +36,10 @@ async fn prf() {
let mut leader_ctx = leader_exec.new_context().await.unwrap();
let mut follower_ctx = follower_exec.new_context().await.unwrap();
let delta = Delta::random(&mut rng.compat_by_ref());
let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_cot(delta.into_inner());
let mut leader_vm = Generator::new(ot_send, [0u8; 16], delta);
let mut leader_vm = Garbler::new(ot_send, [0u8; 16], delta);
let mut follower_vm = Evaluator::new(ot_recv);
let leader_pms: Array<U8, 32> = leader_vm.alloc().unwrap();
@@ -52,23 +52,17 @@ async fn prf() {
follower_vm.assign(follower_pms, pms).unwrap();
follower_vm.commit(follower_pms).unwrap();
let mut leader = MpcPrf::new(PrfConfig::builder().role(Role::Leader).build().unwrap());
let mut follower = MpcPrf::new(PrfConfig::builder().role(Role::Follower).build().unwrap());
let mut leader = MpcPrf::new(Mode::default());
let mut follower = MpcPrf::new(Mode::default());
let leader_output = leader.alloc(&mut leader_vm, leader_pms).unwrap();
let follower_output = follower.alloc(&mut follower_vm, follower_pms).unwrap();
leader
.set_client_random(&mut leader_vm, Some(client_random))
.unwrap();
follower.set_client_random(&mut follower_vm, None).unwrap();
leader.set_client_random(client_random).unwrap();
follower.set_client_random(client_random).unwrap();
leader
.set_server_random(&mut leader_vm, server_random)
.unwrap();
follower
.set_server_random(&mut follower_vm, server_random)
.unwrap();
leader.set_server_random(server_random).unwrap();
follower.set_server_random(server_random).unwrap();
let _ = leader_vm
.decode(leader_output.keys.client_write_key)
@@ -88,44 +82,61 @@ async fn prf() {
let _ = follower_vm.decode(follower_output.keys.client_iv).unwrap();
let _ = follower_vm.decode(follower_output.keys.server_iv).unwrap();
futures::join!(
async {
leader_vm.flush(&mut leader_ctx).await.unwrap();
leader_vm.execute(&mut leader_ctx).await.unwrap();
leader_vm.flush(&mut leader_ctx).await.unwrap();
},
async {
follower_vm.flush(&mut follower_ctx).await.unwrap();
follower_vm.execute(&mut follower_ctx).await.unwrap();
follower_vm.flush(&mut follower_ctx).await.unwrap();
}
);
while leader.wants_flush() || follower.wants_flush() {
tokio::try_join!(
async {
leader.flush(&mut leader_vm).unwrap();
leader_vm.execute_all(&mut leader_ctx).await
},
async {
follower.flush(&mut follower_vm).unwrap();
follower_vm.execute_all(&mut follower_ctx).await
}
)
.unwrap();
}
let cf_hs_hash = [1u8; 32];
let sf_hs_hash = [2u8; 32];
leader.set_cf_hash(&mut leader_vm, cf_hs_hash).unwrap();
leader.set_sf_hash(&mut leader_vm, sf_hs_hash).unwrap();
leader.set_cf_hash(cf_hs_hash).unwrap();
follower.set_cf_hash(cf_hs_hash).unwrap();
follower.set_cf_hash(&mut follower_vm, cf_hs_hash).unwrap();
follower.set_sf_hash(&mut follower_vm, sf_hs_hash).unwrap();
while leader.wants_flush() || follower.wants_flush() {
tokio::try_join!(
async {
leader.flush(&mut leader_vm).unwrap();
leader_vm.execute_all(&mut leader_ctx).await
},
async {
follower.flush(&mut follower_vm).unwrap();
follower_vm.execute_all(&mut follower_ctx).await
}
)
.unwrap();
}
let _ = leader_vm.decode(leader_output.cf_vd).unwrap();
let _ = leader_vm.decode(leader_output.sf_vd).unwrap();
let _ = follower_vm.decode(follower_output.cf_vd).unwrap();
let _ = follower_vm.decode(follower_output.sf_vd).unwrap();
futures::join!(
async {
leader_vm.flush(&mut leader_ctx).await.unwrap();
leader_vm.execute(&mut leader_ctx).await.unwrap();
leader_vm.flush(&mut leader_ctx).await.unwrap();
},
async {
follower_vm.flush(&mut follower_ctx).await.unwrap();
follower_vm.execute(&mut follower_ctx).await.unwrap();
follower_vm.flush(&mut follower_ctx).await.unwrap();
}
);
let sf_hs_hash = [2u8; 32];
leader.set_sf_hash(sf_hs_hash).unwrap();
follower.set_sf_hash(sf_hs_hash).unwrap();
while leader.wants_flush() || follower.wants_flush() {
tokio::try_join!(
async {
leader.flush(&mut leader_vm).unwrap();
leader_vm.execute_all(&mut leader_ctx).await
},
async {
follower.flush(&mut follower_vm).unwrap();
follower_vm.execute_all(&mut follower_ctx).await
}
)
.unwrap();
}
let _ = leader_vm.decode(leader_output.sf_vd).unwrap();
let _ = follower_vm.decode(follower_output.sf_vd).unwrap();
}

View File

@@ -1,24 +1,16 @@
use derive_builder::Builder;
//! PRF modes.
/// Role of this party in the PRF.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
/// The leader provides the private inputs to the PRF.
Leader,
/// The follower is blind to the inputs to the PRF.
Follower,
/// Modes for the PRF.
#[derive(Debug, Clone, Copy)]
pub enum Mode {
/// Computes some hashes locally.
Reduced,
/// Computes the whole PRF in MPC.
Normal,
}
/// Configuration for the PRF.
#[derive(Debug, Builder)]
pub struct PrfConfig {
/// The role of this party in the PRF.
pub(crate) role: Role,
}
impl PrfConfig {
/// Creates a new builder.
pub fn builder() -> PrfConfigBuilder {
PrfConfigBuilder::default()
impl Default for Mode {
fn default() -> Self {
Self::Reduced
}
}

View File

@@ -1,6 +1,8 @@
use core::fmt;
use std::error::Error;
use mpz_hash::sha256::Sha256Error;
/// A PRF error.
#[derive(Debug, thiserror::Error)]
pub struct PrfError {
@@ -20,22 +22,21 @@ impl PrfError {
}
}
pub(crate) fn vm<E: Into<Box<dyn Error + Send + Sync>>>(err: E) -> Self {
Self::new(ErrorKind::Vm, err)
}
pub(crate) fn state(msg: impl Into<String>) -> Self {
Self {
kind: ErrorKind::State,
source: Some(msg.into().into()),
}
}
}
pub(crate) fn role(msg: impl Into<String>) -> Self {
Self {
kind: ErrorKind::Role,
source: Some(msg.into().into()),
}
}
pub(crate) fn vm<E: Into<Box<dyn Error + Send + Sync>>>(err: E) -> Self {
Self::new(ErrorKind::Vm, err)
impl From<Sha256Error> for PrfError {
fn from(value: Sha256Error) -> Self {
Self::new(ErrorKind::Hash, value)
}
}
@@ -43,7 +44,7 @@ impl PrfError {
pub(crate) enum ErrorKind {
Vm,
State,
Role,
Hash,
}
impl fmt::Display for PrfError {
@@ -51,7 +52,7 @@ impl fmt::Display for PrfError {
match self.kind {
ErrorKind::Vm => write!(f, "vm error")?,
ErrorKind::State => write!(f, "state error")?,
ErrorKind::Role => write!(f, "role error")?,
ErrorKind::Hash => write!(f, "hash error")?,
}
if let Some(ref source) = self.source {
@@ -61,9 +62,3 @@ impl fmt::Display for PrfError {
Ok(())
}
}
impl From<mpz_common::ContextError> for PrfError {
fn from(error: mpz_common::ContextError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}

View File

@@ -0,0 +1,177 @@
//! Computation of HMAC-SHA256.
//!
//! HMAC-SHA256 is defined as
//!
//! HMAC(m) = H((key' xor opad) || H((key' xor ipad) || m))
//!
//! * H - SHA256 hash function
//! * key' - key padded with zero bytes to 64 bytes (we do not support longer
//! keys)
//! * opad - 64 bytes of 0x5c
//! * ipad - 64 bytes of 0x36
//! * m - message
//!
//! This implementation computes HMAC-SHA256 using intermediate results
//! `outer_partial` and `inner_local`. Then HMAC(m) = H(outer_partial ||
//! inner_local)
//!
//! * `outer_partial` - key' xor opad
//! * `inner_local` - H((key' xor ipad) || m)
use mpz_hash::sha256::Sha256;
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array,
},
Vm,
};
use crate::PrfError;
pub(crate) const IPAD: [u8; 64] = [0x36; 64];
pub(crate) const OPAD: [u8; 64] = [0x5c; 64];
/// Computes HMAC-SHA256
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `outer_partial` - (key' xor opad)
/// * `inner_local` - H((key' xor ipad) || m)
pub(crate) fn hmac_sha256(
vm: &mut dyn Vm<Binary>,
mut outer_partial: Sha256,
inner_local: Array<U8, 32>,
) -> Result<Array<U8, 32>, PrfError> {
outer_partial.update(&inner_local);
outer_partial.compress(vm)?;
outer_partial.finalize(vm).map_err(PrfError::from)
}
#[cfg(test)]
mod tests {
use crate::{
hmac::hmac_sha256,
sha256, state_to_bytes,
test_utils::{compute_inner_local, compute_outer_partial, mock_vm},
};
use mpz_common::context::test_st_context;
use mpz_hash::sha256::Sha256;
use mpz_vm_core::{
memory::{
binary::{U32, U8},
Array, MemoryExt, ViewExt,
},
Execute,
};
#[test]
fn test_hmac_reference() {
let (inputs, references) = test_fixtures();
for (input, &reference) in inputs.iter().zip(references.iter()) {
let outer_partial = compute_outer_partial(input.0.clone());
let inner_local = compute_inner_local(input.0.clone(), &input.1);
let hmac = sha256(outer_partial, 64, &state_to_bytes(inner_local));
assert_eq!(state_to_bytes(hmac), reference);
}
}
#[tokio::test]
async fn test_hmac_circuit() {
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut leader, mut follower) = mock_vm();
let (inputs, references) = test_fixtures();
for (input, &reference) in inputs.iter().zip(references.iter()) {
let outer_partial = compute_outer_partial(input.0.clone());
let inner_local = compute_inner_local(input.0.clone(), &input.1);
let outer_partial_leader: Array<U32, 8> = leader.alloc().unwrap();
leader.mark_public(outer_partial_leader).unwrap();
leader.assign(outer_partial_leader, outer_partial).unwrap();
leader.commit(outer_partial_leader).unwrap();
let inner_local_leader: Array<U8, 32> = leader.alloc().unwrap();
leader.mark_public(inner_local_leader).unwrap();
leader
.assign(inner_local_leader, state_to_bytes(inner_local))
.unwrap();
leader.commit(inner_local_leader).unwrap();
let hmac_leader = hmac_sha256(
&mut leader,
Sha256::new_from_state(outer_partial_leader, 1),
inner_local_leader,
)
.unwrap();
let hmac_leader = leader.decode(hmac_leader).unwrap();
let outer_partial_follower: Array<U32, 8> = follower.alloc().unwrap();
follower.mark_public(outer_partial_follower).unwrap();
follower
.assign(outer_partial_follower, outer_partial)
.unwrap();
follower.commit(outer_partial_follower).unwrap();
let inner_local_follower: Array<U8, 32> = follower.alloc().unwrap();
follower.mark_public(inner_local_follower).unwrap();
follower
.assign(inner_local_follower, state_to_bytes(inner_local))
.unwrap();
follower.commit(inner_local_follower).unwrap();
let hmac_follower = hmac_sha256(
&mut follower,
Sha256::new_from_state(outer_partial_follower, 1),
inner_local_follower,
)
.unwrap();
let hmac_follower = follower.decode(hmac_follower).unwrap();
let (hmac_leader, hmac_follower) = tokio::try_join!(
async {
leader.execute_all(&mut ctx_a).await.unwrap();
hmac_leader.await
},
async {
follower.execute_all(&mut ctx_b).await.unwrap();
hmac_follower.await
}
)
.unwrap();
assert_eq!(hmac_leader, hmac_follower);
assert_eq!(hmac_leader, reference);
}
}
#[allow(clippy::type_complexity)]
fn test_fixtures() -> (Vec<(Vec<u8>, Vec<u8>)>, Vec<[u8; 32]>) {
let test_vectors: Vec<(Vec<u8>, Vec<u8>)> = vec![
(
hex::decode("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b").unwrap(),
hex::decode("4869205468657265").unwrap(),
),
(
hex::decode("4a656665").unwrap(),
hex::decode("7768617420646f2079612077616e7420666f72206e6f7468696e673f").unwrap(),
),
];
let expected: Vec<[u8; 32]> = vec![
hex::decode("b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7")
.unwrap()
.try_into()
.unwrap(),
hex::decode("5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843")
.unwrap()
.try_into()
.unwrap(),
];
(test_vectors, expected)
}
}

View File

@@ -1,30 +1,24 @@
//! This module contains the protocol for computing TLS SHA-256 HMAC PRF.
//! This crate contains the protocol for computing TLS 1.2 SHA-256 HMAC PRF.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
#![forbid(unsafe_code)]
mod config;
mod error;
mod prf;
mod hmac;
#[cfg(test)]
mod test_utils;
pub use config::{PrfConfig, PrfConfigBuilder, PrfConfigBuilderError, Role};
mod config;
pub use config::Mode;
mod error;
pub use error::PrfError;
mod prf;
pub use prf::MpcPrf;
use mpz_vm_core::memory::{binary::U8, Array};
pub(crate) static CF_LABEL: &[u8] = b"client finished";
pub(crate) static SF_LABEL: &[u8] = b"server finished";
/// Builds the circuits for the PRF.
///
/// This function can be used ahead of time to build the circuits for the PRF,
/// which at the moment is CPU and memory intensive.
pub async fn build_circuits() {
prf::Circuits::get().await;
}
/// PRF output.
#[derive(Debug, Clone, Copy)]
pub struct PrfOutput {
@@ -49,176 +43,227 @@ pub struct SessionKeys {
pub server_iv: Array<U8, 4>,
}
fn sha256(mut state: [u32; 8], pos: usize, msg: &[u8]) -> [u32; 8] {
use sha2::{
compress256,
digest::{
block_buffer::{BlockBuffer, Eager},
generic_array::typenum::U64,
},
};
let mut buffer = BlockBuffer::<U64, Eager>::default();
buffer.digest_blocks(msg, |b| compress256(&mut state, b));
buffer.digest_pad(0x80, &(((msg.len() + pos) * 8) as u64).to_be_bytes(), |b| {
compress256(&mut state, &[*b])
});
state
}
fn state_to_bytes(input: [u32; 8]) -> [u8; 32] {
let mut output = [0_u8; 32];
for (k, byte_chunk) in input.iter().enumerate() {
let byte_chunk = byte_chunk.to_be_bytes();
output[4 * k..4 * (k + 1)].copy_from_slice(&byte_chunk);
}
output
}
#[cfg(test)]
mod tests {
use crate::{
test_utils::{mock_vm, prf_cf_vd, prf_keys, prf_ms, prf_sf_vd},
Mode, MpcPrf, SessionKeys,
};
use mpz_common::context::test_st_context;
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
use mpz_vm_core::{
memory::{binary::U8, Array, MemoryExt, ViewExt},
Execute,
};
use rand::{rngs::StdRng, Rng, SeedableRng};
use hmac_sha256_circuits::{hmac_sha256_partial, prf, session_keys};
use mpz_ot::ideal::cot::ideal_cot;
use mpz_vm_core::{memory::correlated::Delta, prelude::*};
use rand::{rngs::StdRng, SeedableRng};
use rand06_compat::Rand0_6CompatExt;
use super::*;
fn compute_ms(pms: [u8; 32], client_random: [u8; 32], server_random: [u8; 32]) -> [u8; 48] {
let (outer_state, inner_state) = hmac_sha256_partial(&pms);
let seed = client_random
.iter()
.chain(&server_random)
.copied()
.collect::<Vec<_>>();
let ms = prf(outer_state, inner_state, &seed, b"master secret", 48);
ms.try_into().unwrap()
}
fn compute_vd(ms: [u8; 48], label: &[u8], hs_hash: [u8; 32]) -> [u8; 12] {
let (outer_state, inner_state) = hmac_sha256_partial(&ms);
let vd = prf(outer_state, inner_state, &hs_hash, label, 12);
vd.try_into().unwrap()
}
#[ignore = "expensive"]
#[tokio::test]
async fn test_prf() {
let mut rng = StdRng::seed_from_u64(0);
async fn test_prf_reduced() {
let mode = Mode::Reduced;
test_prf(mode).await;
}
let pms = [42u8; 32];
let client_random = [69u8; 32];
let server_random: [u8; 32] = [96u8; 32];
let ms = compute_ms(pms, client_random, server_random);
#[tokio::test]
async fn test_prf_normal() {
let mode = Mode::Normal;
test_prf(mode).await;
}
let (mut leader_ctx, mut follower_ctx) = test_st_context(128);
async fn test_prf(mode: Mode) {
let mut rng = StdRng::seed_from_u64(1);
// Test input
let pms: [u8; 32] = rng.random();
let client_random: [u8; 32] = rng.random();
let server_random: [u8; 32] = rng.random();
let delta = Delta::random(&mut rng.compat_by_ref());
let (ot_send, ot_recv) = ideal_cot(delta.into_inner());
let cf_hs_hash: [u8; 32] = rng.random();
let sf_hs_hash: [u8; 32] = rng.random();
let mut leader_vm = Generator::new(ot_send, [0u8; 16], delta);
let mut follower_vm = Evaluator::new(ot_recv);
// Expected output
let ms_expected = prf_ms(pms, client_random, server_random);
let leader_pms: Array<U8, 32> = leader_vm.alloc().unwrap();
leader_vm.mark_public(leader_pms).unwrap();
leader_vm.assign(leader_pms, pms).unwrap();
leader_vm.commit(leader_pms).unwrap();
let [cwk_expected, swk_expected, civ_expected, siv_expected] =
prf_keys(ms_expected, client_random, server_random);
let follower_pms: Array<U8, 32> = follower_vm.alloc().unwrap();
follower_vm.mark_public(follower_pms).unwrap();
follower_vm.assign(follower_pms, pms).unwrap();
follower_vm.commit(follower_pms).unwrap();
let cwk_expected: [u8; 16] = cwk_expected.try_into().unwrap();
let swk_expected: [u8; 16] = swk_expected.try_into().unwrap();
let civ_expected: [u8; 4] = civ_expected.try_into().unwrap();
let siv_expected: [u8; 4] = siv_expected.try_into().unwrap();
let mut leader = MpcPrf::new(PrfConfig::builder().role(Role::Leader).build().unwrap());
let mut follower = MpcPrf::new(PrfConfig::builder().role(Role::Follower).build().unwrap());
let cf_vd_expected = prf_cf_vd(ms_expected, cf_hs_hash);
let sf_vd_expected = prf_sf_vd(ms_expected, sf_hs_hash);
let leader_output = leader.alloc(&mut leader_vm, leader_pms).unwrap();
let follower_output = follower.alloc(&mut follower_vm, follower_pms).unwrap();
let cf_vd_expected: [u8; 12] = cf_vd_expected.try_into().unwrap();
let sf_vd_expected: [u8; 12] = sf_vd_expected.try_into().unwrap();
leader
.set_client_random(&mut leader_vm, Some(client_random))
// Set up vm and prf
let (mut ctx_a, mut ctx_b) = test_st_context(128);
let (mut leader, mut follower) = mock_vm();
let leader_pms: Array<U8, 32> = leader.alloc().unwrap();
leader.mark_public(leader_pms).unwrap();
leader.assign(leader_pms, pms).unwrap();
leader.commit(leader_pms).unwrap();
let follower_pms: Array<U8, 32> = follower.alloc().unwrap();
follower.mark_public(follower_pms).unwrap();
follower.assign(follower_pms, pms).unwrap();
follower.commit(follower_pms).unwrap();
let mut prf_leader = MpcPrf::new(mode);
let mut prf_follower = MpcPrf::new(mode);
let leader_prf_out = prf_leader.alloc(&mut leader, leader_pms).unwrap();
let follower_prf_out = prf_follower.alloc(&mut follower, follower_pms).unwrap();
// client_random and server_random
prf_leader.set_client_random(client_random).unwrap();
prf_follower.set_client_random(client_random).unwrap();
prf_leader.set_server_random(server_random).unwrap();
prf_follower.set_server_random(server_random).unwrap();
let SessionKeys {
client_write_key: cwk_leader,
server_write_key: swk_leader,
client_iv: civ_leader,
server_iv: siv_leader,
} = leader_prf_out.keys;
let mut cwk_leader = leader.decode(cwk_leader).unwrap();
let mut swk_leader = leader.decode(swk_leader).unwrap();
let mut civ_leader = leader.decode(civ_leader).unwrap();
let mut siv_leader = leader.decode(siv_leader).unwrap();
let SessionKeys {
client_write_key: cwk_follower,
server_write_key: swk_follower,
client_iv: civ_follower,
server_iv: siv_follower,
} = follower_prf_out.keys;
let mut cwk_follower = follower.decode(cwk_follower).unwrap();
let mut swk_follower = follower.decode(swk_follower).unwrap();
let mut civ_follower = follower.decode(civ_follower).unwrap();
let mut siv_follower = follower.decode(siv_follower).unwrap();
while prf_leader.wants_flush() || prf_follower.wants_flush() {
tokio::try_join!(
async {
prf_leader.flush(&mut leader).unwrap();
leader.execute_all(&mut ctx_a).await
},
async {
prf_follower.flush(&mut follower).unwrap();
follower.execute_all(&mut ctx_b).await
}
)
.unwrap();
follower.set_client_random(&mut follower_vm, None).unwrap();
}
leader
.set_server_random(&mut leader_vm, server_random)
let cwk_leader = cwk_leader.try_recv().unwrap().unwrap();
let swk_leader = swk_leader.try_recv().unwrap().unwrap();
let civ_leader = civ_leader.try_recv().unwrap().unwrap();
let siv_leader = siv_leader.try_recv().unwrap().unwrap();
let cwk_follower = cwk_follower.try_recv().unwrap().unwrap();
let swk_follower = swk_follower.try_recv().unwrap().unwrap();
let civ_follower = civ_follower.try_recv().unwrap().unwrap();
let siv_follower = siv_follower.try_recv().unwrap().unwrap();
assert_eq!(cwk_leader, cwk_follower);
assert_eq!(swk_leader, swk_follower);
assert_eq!(civ_leader, civ_follower);
assert_eq!(siv_leader, siv_follower);
assert_eq!(cwk_leader, cwk_expected);
assert_eq!(swk_leader, swk_expected);
assert_eq!(civ_leader, civ_expected);
assert_eq!(siv_leader, siv_expected);
// client finished
prf_leader.set_cf_hash(cf_hs_hash).unwrap();
prf_follower.set_cf_hash(cf_hs_hash).unwrap();
let cf_vd_leader = leader_prf_out.cf_vd;
let cf_vd_follower = follower_prf_out.cf_vd;
let mut cf_vd_leader = leader.decode(cf_vd_leader).unwrap();
let mut cf_vd_follower = follower.decode(cf_vd_follower).unwrap();
while prf_leader.wants_flush() || prf_follower.wants_flush() {
tokio::try_join!(
async {
prf_leader.flush(&mut leader).unwrap();
leader.execute_all(&mut ctx_a).await
},
async {
prf_follower.flush(&mut follower).unwrap();
follower.execute_all(&mut ctx_b).await
}
)
.unwrap();
follower
.set_server_random(&mut follower_vm, server_random)
}
let cf_vd_leader = cf_vd_leader.try_recv().unwrap().unwrap();
let cf_vd_follower = cf_vd_follower.try_recv().unwrap().unwrap();
assert_eq!(cf_vd_leader, cf_vd_follower);
assert_eq!(cf_vd_leader, cf_vd_expected);
// server finished
prf_leader.set_sf_hash(sf_hs_hash).unwrap();
prf_follower.set_sf_hash(sf_hs_hash).unwrap();
let sf_vd_leader = leader_prf_out.sf_vd;
let sf_vd_follower = follower_prf_out.sf_vd;
let mut sf_vd_leader = leader.decode(sf_vd_leader).unwrap();
let mut sf_vd_follower = follower.decode(sf_vd_follower).unwrap();
while prf_leader.wants_flush() || prf_follower.wants_flush() {
tokio::try_join!(
async {
prf_leader.flush(&mut leader).unwrap();
leader.execute_all(&mut ctx_a).await
},
async {
prf_follower.flush(&mut follower).unwrap();
follower.execute_all(&mut ctx_b).await
}
)
.unwrap();
}
let leader_cwk = leader_vm
.decode(leader_output.keys.client_write_key)
.unwrap();
let leader_swk = leader_vm
.decode(leader_output.keys.server_write_key)
.unwrap();
let leader_civ = leader_vm.decode(leader_output.keys.client_iv).unwrap();
let leader_siv = leader_vm.decode(leader_output.keys.server_iv).unwrap();
let sf_vd_leader = sf_vd_leader.try_recv().unwrap().unwrap();
let sf_vd_follower = sf_vd_follower.try_recv().unwrap().unwrap();
let follower_cwk = follower_vm
.decode(follower_output.keys.client_write_key)
.unwrap();
let follower_swk = follower_vm
.decode(follower_output.keys.server_write_key)
.unwrap();
let follower_civ = follower_vm.decode(follower_output.keys.client_iv).unwrap();
let follower_siv = follower_vm.decode(follower_output.keys.server_iv).unwrap();
futures::join!(
async {
leader_vm.flush(&mut leader_ctx).await.unwrap();
leader_vm.execute(&mut leader_ctx).await.unwrap();
leader_vm.flush(&mut leader_ctx).await.unwrap();
},
async {
follower_vm.flush(&mut follower_ctx).await.unwrap();
follower_vm.execute(&mut follower_ctx).await.unwrap();
follower_vm.flush(&mut follower_ctx).await.unwrap();
}
);
let leader_cwk = leader_cwk.await.unwrap();
let leader_swk = leader_swk.await.unwrap();
let leader_civ = leader_civ.await.unwrap();
let leader_siv = leader_siv.await.unwrap();
let follower_cwk = follower_cwk.await.unwrap();
let follower_swk = follower_swk.await.unwrap();
let follower_civ = follower_civ.await.unwrap();
let follower_siv = follower_siv.await.unwrap();
let (expected_cwk, expected_swk, expected_civ, expected_siv) =
session_keys(pms, client_random, server_random);
assert_eq!(leader_cwk, expected_cwk);
assert_eq!(leader_swk, expected_swk);
assert_eq!(leader_civ, expected_civ);
assert_eq!(leader_siv, expected_siv);
assert_eq!(follower_cwk, expected_cwk);
assert_eq!(follower_swk, expected_swk);
assert_eq!(follower_civ, expected_civ);
assert_eq!(follower_siv, expected_siv);
let cf_hs_hash = [1u8; 32];
let sf_hs_hash = [2u8; 32];
leader.set_cf_hash(&mut leader_vm, cf_hs_hash).unwrap();
leader.set_sf_hash(&mut leader_vm, sf_hs_hash).unwrap();
follower.set_cf_hash(&mut follower_vm, cf_hs_hash).unwrap();
follower.set_sf_hash(&mut follower_vm, sf_hs_hash).unwrap();
let leader_cf_vd = leader_vm.decode(leader_output.cf_vd).unwrap();
let leader_sf_vd = leader_vm.decode(leader_output.sf_vd).unwrap();
let follower_cf_vd = follower_vm.decode(follower_output.cf_vd).unwrap();
let follower_sf_vd = follower_vm.decode(follower_output.sf_vd).unwrap();
futures::join!(
async {
leader_vm.flush(&mut leader_ctx).await.unwrap();
leader_vm.execute(&mut leader_ctx).await.unwrap();
leader_vm.flush(&mut leader_ctx).await.unwrap();
},
async {
follower_vm.flush(&mut follower_ctx).await.unwrap();
follower_vm.execute(&mut follower_ctx).await.unwrap();
follower_vm.flush(&mut follower_ctx).await.unwrap();
}
);
let leader_cf_vd = leader_cf_vd.await.unwrap();
let leader_sf_vd = leader_sf_vd.await.unwrap();
let follower_cf_vd = follower_cf_vd.await.unwrap();
let follower_sf_vd = follower_sf_vd.await.unwrap();
let expected_cf_vd = compute_vd(ms, b"client finished", cf_hs_hash);
let expected_sf_vd = compute_vd(ms, b"server finished", sf_hs_hash);
assert_eq!(leader_cf_vd, expected_cf_vd);
assert_eq!(leader_sf_vd, expected_sf_vd);
assert_eq!(follower_cf_vd, expected_cf_vd);
assert_eq!(follower_sf_vd, expected_sf_vd);
assert_eq!(sf_vd_leader, sf_vd_follower);
assert_eq!(sf_vd_leader, sf_vd_expected);
}
}

View File

@@ -1,98 +1,41 @@
use std::{
fmt::Debug,
sync::{Arc, OnceLock},
use crate::{
hmac::{IPAD, OPAD},
Mode, PrfError, PrfOutput,
};
use hmac_sha256_circuits::{build_session_keys, build_verify_data};
use mpz_circuits::Circuit;
use mpz_common::cpu::CpuBackend;
use mpz_circuits::{circuits::xor, Circuit, CircuitBuilder};
use mpz_hash::sha256::Sha256;
use mpz_vm_core::{
memory::{
binary::{Binary, U32, U8},
Array,
binary::{Binary, U8},
Array, MemoryExt, StaticSize, Vector, ViewExt,
},
prelude::*,
Call, Vm,
Call, CallableExt, Vm,
};
use std::{fmt::Debug, sync::Arc};
use tracing::instrument;
use crate::{PrfConfig, PrfError, PrfOutput, Role, SessionKeys, CF_LABEL, SF_LABEL};
mod state;
use state::State;
pub(crate) struct Circuits {
session_keys: Arc<Circuit>,
client_vd: Arc<Circuit>,
server_vd: Arc<Circuit>,
}
impl Circuits {
pub(crate) async fn get() -> &'static Self {
static CIRCUITS: OnceLock<Circuits> = OnceLock::new();
if let Some(circuits) = CIRCUITS.get() {
return circuits;
}
let (session_keys, client_vd, server_vd) = futures::join!(
CpuBackend::blocking(build_session_keys),
CpuBackend::blocking(|| build_verify_data(CF_LABEL)),
CpuBackend::blocking(|| build_verify_data(SF_LABEL)),
);
_ = CIRCUITS.set(Circuits {
session_keys,
client_vd,
server_vd,
});
CIRCUITS.get().unwrap()
}
}
mod function;
use function::Prf;
/// MPC PRF for computing TLS 1.2 HMAC-SHA256 PRF.
#[derive(Debug)]
pub(crate) enum State {
Initialized,
SessionKeys {
client_random: Array<U8, 32>,
server_random: Array<U8, 32>,
cf_hash: Array<U8, 32>,
sf_hash: Array<U8, 32>,
},
ClientFinished {
cf_hash: Array<U8, 32>,
sf_hash: Array<U8, 32>,
},
ServerFinished {
sf_hash: Array<U8, 32>,
},
Complete,
Error,
}
impl State {
fn take(&mut self) -> State {
std::mem::replace(self, State::Error)
}
}
/// MPC PRF for computing TLS HMAC-SHA256 PRF.
pub struct MpcPrf {
config: PrfConfig,
mode: Mode,
state: State,
}
impl Debug for MpcPrf {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MpcPrf")
.field("config", &self.config)
.field("state", &self.state)
.finish()
}
}
impl MpcPrf {
/// Creates a new instance of the PRF.
pub fn new(config: PrfConfig) -> MpcPrf {
MpcPrf {
config,
///
/// # Arguments
///
/// `mode` - The PRF mode.
pub fn new(mode: Mode) -> MpcPrf {
Self {
mode,
state: State::Initialized,
}
}
@@ -113,122 +56,58 @@ impl MpcPrf {
return Err(PrfError::state("PRF not in initialized state"));
};
let circuits = futures::executor::block_on(Circuits::get());
let mode = self.mode;
let pms: Vector<U8> = pms.into();
let client_random = vm.alloc().map_err(PrfError::vm)?;
let server_random = vm.alloc().map_err(PrfError::vm)?;
let outer_partial_pms = compute_partial(vm, pms, OPAD)?;
let inner_partial_pms = compute_partial(vm, pms, IPAD)?;
// The client random is kept private so that the handshake transcript
// hashes do not leak information about the server's identity.
match self.config.role {
Role::Leader => vm.mark_private(client_random),
Role::Follower => vm.mark_blind(client_random),
}
.map_err(PrfError::vm)?;
let master_secret =
Prf::alloc_master_secret(mode, vm, outer_partial_pms, inner_partial_pms)?;
let ms = master_secret.output();
let ms = merge_outputs(vm, ms, 48)?;
vm.mark_public(server_random).map_err(PrfError::vm)?;
let outer_partial_ms = compute_partial(vm, ms, OPAD)?;
let inner_partial_ms = compute_partial(vm, ms, IPAD)?;
#[allow(clippy::type_complexity)]
let (
client_write_key,
server_write_key,
client_iv,
server_iv,
ms_outer_hash_state,
ms_inner_hash_state,
): (
Array<U8, 16>,
Array<U8, 16>,
Array<U8, 4>,
Array<U8, 4>,
Array<U32, 8>,
Array<U32, 8>,
) = vm
.call(
Call::builder(circuits.session_keys.clone())
.arg(pms)
.arg(client_random)
.arg(server_random)
.build()
.map_err(PrfError::vm)?,
)
.map_err(PrfError::vm)?;
let keys = SessionKeys {
client_write_key,
server_write_key,
client_iv,
server_iv,
};
let cf_hash = vm.alloc().map_err(PrfError::vm)?;
vm.mark_public(cf_hash).map_err(PrfError::vm)?;
let cf_vd = vm
.call(
Call::builder(circuits.client_vd.clone())
.arg(ms_outer_hash_state)
.arg(ms_inner_hash_state)
.arg(cf_hash)
.build()
.map_err(PrfError::vm)?,
)
.map_err(PrfError::vm)?;
let sf_hash = vm.alloc().map_err(PrfError::vm)?;
vm.mark_public(sf_hash).map_err(PrfError::vm)?;
let sf_vd = vm
.call(
Call::builder(circuits.server_vd.clone())
.arg(ms_outer_hash_state)
.arg(ms_inner_hash_state)
.arg(sf_hash)
.build()
.map_err(PrfError::vm)?,
)
.map_err(PrfError::vm)?;
let key_expansion =
Prf::alloc_key_expansion(mode, vm, outer_partial_ms.clone(), inner_partial_ms.clone())?;
let client_finished = Prf::alloc_client_finished(
mode,
vm,
outer_partial_ms.clone(),
inner_partial_ms.clone(),
)?;
let server_finished = Prf::alloc_server_finished(
mode,
vm,
outer_partial_ms.clone(),
inner_partial_ms.clone(),
)?;
self.state = State::SessionKeys {
client_random,
server_random,
cf_hash,
sf_hash,
client_random: None,
master_secret,
key_expansion,
client_finished,
server_finished,
};
Ok(PrfOutput { keys, cf_vd, sf_vd })
self.state.prf_output(vm)
}
/// Sets the client random.
///
/// Only the leader can provide the client random.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `client_random` - The client random.
/// * `random` - The client random.
#[instrument(level = "debug", skip_all, err)]
pub fn set_client_random(
&mut self,
vm: &mut dyn Vm<Binary>,
random: Option<[u8; 32]>,
) -> Result<(), PrfError> {
let State::SessionKeys { client_random, .. } = &self.state else {
pub fn set_client_random(&mut self, random: [u8; 32]) -> Result<(), PrfError> {
let State::SessionKeys { client_random, .. } = &mut self.state else {
return Err(PrfError::state("PRF not set up"));
};
if self.config.role == Role::Leader {
let Some(random) = random else {
return Err(PrfError::role("leader must provide client random"));
};
vm.assign(*client_random, random).map_err(PrfError::vm)?;
} else if random.is_some() {
return Err(PrfError::role("only leader can set client random"));
}
vm.commit(*client_random).map_err(PrfError::vm)?;
*client_random = Some(random);
Ok(())
}
@@ -236,28 +115,29 @@ impl MpcPrf {
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `server_random` - The server random.
/// * `random` - The server random.
#[instrument(level = "debug", skip_all, err)]
pub fn set_server_random(
&mut self,
vm: &mut dyn Vm<Binary>,
random: [u8; 32],
) -> Result<(), PrfError> {
pub fn set_server_random(&mut self, random: [u8; 32]) -> Result<(), PrfError> {
let State::SessionKeys {
server_random,
cf_hash,
sf_hash,
client_random,
master_secret,
key_expansion,
..
} = self.state.take()
} = &mut self.state
else {
return Err(PrfError::state("PRF not set up"));
};
vm.assign(server_random, random).map_err(PrfError::vm)?;
vm.commit(server_random).map_err(PrfError::vm)?;
let client_random = client_random.expect("Client random should have been set by now");
let server_random = random;
self.state = State::ClientFinished { cf_hash, sf_hash };
let mut seed_ms = client_random.to_vec();
seed_ms.extend_from_slice(&server_random);
master_secret.set_start_seed(seed_ms);
let mut seed_ke = server_random.to_vec();
seed_ke.extend_from_slice(&client_random);
key_expansion.set_start_seed(seed_ke);
Ok(())
}
@@ -266,22 +146,18 @@ impl MpcPrf {
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `handshake_hash` - The handshake transcript hash.
#[instrument(level = "debug", skip_all, err)]
pub fn set_cf_hash(
&mut self,
vm: &mut dyn Vm<Binary>,
handshake_hash: [u8; 32],
) -> Result<(), PrfError> {
let State::ClientFinished { cf_hash, sf_hash } = self.state.take() else {
pub fn set_cf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), PrfError> {
let State::ClientFinished {
client_finished, ..
} = &mut self.state
else {
return Err(PrfError::state("PRF not in client finished state"));
};
vm.assign(cf_hash, handshake_hash).map_err(PrfError::vm)?;
vm.commit(cf_hash).map_err(PrfError::vm)?;
self.state = State::ServerFinished { sf_hash };
let seed_cf = handshake_hash.to_vec();
client_finished.set_start_seed(seed_cf);
Ok(())
}
@@ -290,23 +166,242 @@ impl MpcPrf {
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `handshake_hash` - The handshake transcript hash.
#[instrument(level = "debug", skip_all, err)]
pub fn set_sf_hash(
&mut self,
vm: &mut dyn Vm<Binary>,
handshake_hash: [u8; 32],
) -> Result<(), PrfError> {
let State::ServerFinished { sf_hash } = self.state.take() else {
pub fn set_sf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), PrfError> {
let State::ServerFinished { server_finished } = &mut self.state else {
return Err(PrfError::state("PRF not in server finished state"));
};
vm.assign(sf_hash, handshake_hash).map_err(PrfError::vm)?;
vm.commit(sf_hash).map_err(PrfError::vm)?;
let seed_sf = handshake_hash.to_vec();
server_finished.set_start_seed(seed_sf);
self.state = State::Complete;
Ok(())
}
/// Returns if the PRF needs to be flushed.
pub fn wants_flush(&self) -> bool {
match &self.state {
State::Initialized => false,
State::SessionKeys {
master_secret,
key_expansion,
..
} => master_secret.wants_flush() || key_expansion.wants_flush(),
State::ClientFinished {
client_finished, ..
} => client_finished.wants_flush(),
State::ServerFinished { server_finished } => server_finished.wants_flush(),
State::Complete => false,
State::Error => false,
}
}
/// Flushes the PRF.
pub fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), PrfError> {
self.state = match self.state.take() {
State::SessionKeys {
client_random,
mut master_secret,
mut key_expansion,
client_finished,
server_finished,
} => {
master_secret.flush(vm)?;
key_expansion.flush(vm)?;
if !master_secret.wants_flush() && !key_expansion.wants_flush() {
State::ClientFinished {
client_finished,
server_finished,
}
} else {
State::SessionKeys {
client_random,
master_secret,
key_expansion,
client_finished,
server_finished,
}
}
}
State::ClientFinished {
mut client_finished,
server_finished,
} => {
client_finished.flush(vm)?;
if !client_finished.wants_flush() {
State::ServerFinished { server_finished }
} else {
State::ClientFinished {
client_finished,
server_finished,
}
}
}
State::ServerFinished {
mut server_finished,
} => {
server_finished.flush(vm)?;
if !server_finished.wants_flush() {
State::Complete
} else {
State::ServerFinished { server_finished }
}
}
other => other,
};
Ok(())
}
}
/// Depending on the provided `mask` computes and returns `outer_partial` or
/// `inner_partial` for HMAC-SHA256.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `key` - Key to pad and xor.
/// * `mask`- Mask used for padding.
fn compute_partial(
vm: &mut dyn Vm<Binary>,
key: Vector<U8>,
mask: [u8; 64],
) -> Result<Sha256, PrfError> {
let xor = Arc::new(xor(8 * 64));
let additional_len = 64 - key.len();
let padding = vec![0_u8; additional_len];
let padding_ref: Vector<U8> = vm.alloc_vec(additional_len).map_err(PrfError::vm)?;
vm.mark_public(padding_ref).map_err(PrfError::vm)?;
vm.assign(padding_ref, padding).map_err(PrfError::vm)?;
vm.commit(padding_ref).map_err(PrfError::vm)?;
let mask_ref: Array<U8, 64> = vm.alloc().map_err(PrfError::vm)?;
vm.mark_public(mask_ref).map_err(PrfError::vm)?;
vm.assign(mask_ref, mask).map_err(PrfError::vm)?;
vm.commit(mask_ref).map_err(PrfError::vm)?;
let xor = Call::builder(xor)
.arg(key)
.arg(padding_ref)
.arg(mask_ref)
.build()
.map_err(PrfError::vm)?;
let key_padded: Vector<U8> = vm.call(xor).map_err(PrfError::vm)?;
let mut sha = Sha256::new_with_init(vm)?;
sha.update(&key_padded);
sha.compress(vm)?;
Ok(sha)
}
fn merge_outputs(
vm: &mut dyn Vm<Binary>,
inputs: Vec<Array<U8, 32>>,
output_bytes: usize,
) -> Result<Vector<U8>, PrfError> {
assert!(output_bytes <= 32 * inputs.len());
let bits = Array::<U8, 32>::SIZE * inputs.len();
let circ = gen_merge_circ(bits);
let mut builder = Call::builder(circ);
for &input in inputs.iter() {
builder = builder.arg(input);
}
let call = builder.build().map_err(PrfError::vm)?;
let mut output: Vector<U8> = vm.call(call).map_err(PrfError::vm)?;
output.truncate(output_bytes);
Ok(output)
}
fn gen_merge_circ(size: usize) -> Arc<Circuit> {
let mut builder = CircuitBuilder::new();
let inputs = (0..size).map(|_| builder.add_input()).collect::<Vec<_>>();
for input in inputs.chunks_exact(8) {
for byte in input.chunks_exact(8) {
for &feed in byte.iter() {
let output = builder.add_id_gate(feed);
builder.add_output(output);
}
}
}
Arc::new(builder.build().expect("merge circuit is valid"))
}
#[cfg(test)]
mod tests {
use crate::{prf::merge_outputs, test_utils::mock_vm};
use mpz_common::context::test_st_context;
use mpz_vm_core::{
memory::{binary::U8, Array, MemoryExt, ViewExt},
Execute,
};
#[tokio::test]
async fn test_merge_outputs() {
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut leader, mut follower) = mock_vm();
let input1: [u8; 32] = std::array::from_fn(|i| i as u8);
let input2: [u8; 32] = std::array::from_fn(|i| i as u8 + 32);
let mut expected = input1.to_vec();
expected.extend_from_slice(&input2);
expected.truncate(48);
// leader
let input1_leader: Array<U8, 32> = leader.alloc().unwrap();
let input2_leader: Array<U8, 32> = leader.alloc().unwrap();
leader.mark_public(input1_leader).unwrap();
leader.mark_public(input2_leader).unwrap();
leader.assign(input1_leader, input1).unwrap();
leader.assign(input2_leader, input2).unwrap();
leader.commit(input1_leader).unwrap();
leader.commit(input2_leader).unwrap();
let merged_leader =
merge_outputs(&mut leader, vec![input1_leader, input2_leader], 48).unwrap();
let mut merged_leader = leader.decode(merged_leader).unwrap();
// follower
let input1_follower: Array<U8, 32> = follower.alloc().unwrap();
let input2_follower: Array<U8, 32> = follower.alloc().unwrap();
follower.mark_public(input1_follower).unwrap();
follower.mark_public(input2_follower).unwrap();
follower.assign(input1_follower, input1).unwrap();
follower.assign(input2_follower, input2).unwrap();
follower.commit(input1_follower).unwrap();
follower.commit(input2_follower).unwrap();
let merged_follower =
merge_outputs(&mut follower, vec![input1_follower, input2_follower], 48).unwrap();
let mut merged_follower = follower.decode(merged_follower).unwrap();
tokio::try_join!(
leader.execute_all(&mut ctx_a),
follower.execute_all(&mut ctx_b)
)
.unwrap();
let merged_leader = merged_leader.try_recv().unwrap().unwrap();
let merged_follower = merged_follower.try_recv().unwrap().unwrap();
assert_eq!(merged_leader, merged_follower);
assert_eq!(merged_leader, expected);
}
}

View File

@@ -0,0 +1,257 @@
//! Provides [`Prf`], for computing the TLS 1.2 PRF.
use crate::{Mode, PrfError};
use mpz_hash::sha256::Sha256;
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array,
},
Vm,
};
mod normal;
mod reduced;
#[derive(Debug)]
pub(crate) enum Prf {
Reduced(reduced::PrfFunction),
Normal(normal::PrfFunction),
}
impl Prf {
pub(crate) fn alloc_master_secret(
mode: Mode,
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
let prf = match mode {
Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_master_secret(
vm,
outer_partial,
inner_partial,
)?),
Mode::Normal => Self::Normal(normal::PrfFunction::alloc_master_secret(
vm,
outer_partial,
inner_partial,
)?),
};
Ok(prf)
}
pub(crate) fn alloc_key_expansion(
mode: Mode,
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
let prf = match mode {
Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_key_expansion(
vm,
outer_partial,
inner_partial,
)?),
Mode::Normal => Self::Normal(normal::PrfFunction::alloc_key_expansion(
vm,
outer_partial,
inner_partial,
)?),
};
Ok(prf)
}
pub(crate) fn alloc_client_finished(
config: Mode,
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
let prf = match config {
Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_client_finished(
vm,
outer_partial,
inner_partial,
)?),
Mode::Normal => Self::Normal(normal::PrfFunction::alloc_client_finished(
vm,
outer_partial,
inner_partial,
)?),
};
Ok(prf)
}
pub(crate) fn alloc_server_finished(
config: Mode,
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
let prf = match config {
Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_server_finished(
vm,
outer_partial,
inner_partial,
)?),
Mode::Normal => Self::Normal(normal::PrfFunction::alloc_server_finished(
vm,
outer_partial,
inner_partial,
)?),
};
Ok(prf)
}
pub(crate) fn wants_flush(&self) -> bool {
match self {
Prf::Reduced(prf) => prf.wants_flush(),
Prf::Normal(prf) => prf.wants_flush(),
}
}
pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), PrfError> {
match self {
Prf::Reduced(prf) => prf.flush(vm),
Prf::Normal(prf) => prf.flush(vm),
}
}
pub(crate) fn set_start_seed(&mut self, seed: Vec<u8>) {
match self {
Prf::Reduced(prf) => prf.set_start_seed(seed),
Prf::Normal(prf) => prf.set_start_seed(seed),
}
}
pub(crate) fn output(&self) -> Vec<Array<U8, 32>> {
match self {
Prf::Reduced(prf) => prf.output(),
Prf::Normal(prf) => prf.output(),
}
}
}
#[cfg(test)]
mod tests {
use crate::{
prf::{compute_partial, function::Prf},
test_utils::{mock_vm, phash},
Mode,
};
use mpz_common::context::test_st_context;
use mpz_vm_core::{
memory::{binary::U8, Array, MemoryExt, ViewExt},
Execute,
};
use rand::{rngs::ThreadRng, Rng};
const IPAD: [u8; 64] = [0x36; 64];
const OPAD: [u8; 64] = [0x5c; 64];
#[tokio::test]
async fn test_phash_reduced() {
let mode = Mode::Reduced;
test_phash(mode).await;
}
#[tokio::test]
async fn test_phash_normal() {
let mode = Mode::Normal;
test_phash(mode).await;
}
async fn test_phash(mode: Mode) {
let mut rng = ThreadRng::default();
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut leader, mut follower) = mock_vm();
let key: [u8; 32] = rng.random();
let start_seed: Vec<u8> = vec![42; 64];
let mut label_seed = b"master secret".to_vec();
label_seed.extend_from_slice(&start_seed);
let iterations = 2;
let leader_key: Array<U8, 32> = leader.alloc().unwrap();
leader.mark_public(leader_key).unwrap();
leader.assign(leader_key, key).unwrap();
leader.commit(leader_key).unwrap();
let outer_partial_leader = compute_partial(&mut leader, leader_key.into(), OPAD).unwrap();
let inner_partial_leader = compute_partial(&mut leader, leader_key.into(), IPAD).unwrap();
let mut prf_leader = Prf::alloc_master_secret(
mode,
&mut leader,
outer_partial_leader,
inner_partial_leader,
)
.unwrap();
prf_leader.set_start_seed(start_seed.clone());
let mut prf_out_leader = vec![];
for p in prf_leader.output() {
let p_out = leader.decode(p).unwrap();
prf_out_leader.push(p_out)
}
let follower_key: Array<U8, 32> = follower.alloc().unwrap();
follower.mark_public(follower_key).unwrap();
follower.assign(follower_key, key).unwrap();
follower.commit(follower_key).unwrap();
let outer_partial_follower =
compute_partial(&mut follower, follower_key.into(), OPAD).unwrap();
let inner_partial_follower =
compute_partial(&mut follower, follower_key.into(), IPAD).unwrap();
let mut prf_follower = Prf::alloc_master_secret(
mode,
&mut follower,
outer_partial_follower,
inner_partial_follower,
)
.unwrap();
prf_follower.set_start_seed(start_seed.clone());
let mut prf_out_follower = vec![];
for p in prf_follower.output() {
let p_out = follower.decode(p).unwrap();
prf_out_follower.push(p_out)
}
while prf_leader.wants_flush() || prf_follower.wants_flush() {
tokio::try_join!(
async {
prf_leader.flush(&mut leader).unwrap();
leader.execute_all(&mut ctx_a).await
},
async {
prf_follower.flush(&mut follower).unwrap();
follower.execute_all(&mut ctx_b).await
}
)
.unwrap();
}
assert_eq!(prf_out_leader.len(), 2);
assert_eq!(prf_out_leader.len(), prf_out_follower.len());
let prf_result_leader: Vec<u8> = prf_out_leader
.iter_mut()
.flat_map(|p| p.try_recv().unwrap().unwrap())
.collect();
let prf_result_follower: Vec<u8> = prf_out_follower
.iter_mut()
.flat_map(|p| p.try_recv().unwrap().unwrap())
.collect();
let expected = phash(key.to_vec(), &label_seed, iterations);
assert_eq!(prf_result_leader, prf_result_follower);
assert_eq!(prf_result_leader, expected)
}
}

View File

@@ -0,0 +1,174 @@
//! Computes the whole PRF in MPC.
use crate::{hmac::hmac_sha256, PrfError};
use mpz_hash::sha256::Sha256;
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array, MemoryExt, Vector, ViewExt,
},
Vm,
};
#[derive(Debug)]
pub(crate) struct PrfFunction {
// The label, e.g. "master secret".
label: &'static [u8],
state: State,
// The start seed and the label, e.g. client_random + server_random + "master_secret".
start_seed_label: Option<Vec<u8>>,
a: Vec<PHash>,
p: Vec<PHash>,
}
impl PrfFunction {
const MS_LABEL: &[u8] = b"master secret";
const KEY_LABEL: &[u8] = b"key expansion";
const CF_LABEL: &[u8] = b"client finished";
const SF_LABEL: &[u8] = b"server finished";
pub(crate) fn alloc_master_secret(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
Self::alloc(vm, Self::MS_LABEL, outer_partial, inner_partial, 48, 64)
}
pub(crate) fn alloc_key_expansion(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
Self::alloc(vm, Self::KEY_LABEL, outer_partial, inner_partial, 40, 64)
}
pub(crate) fn alloc_client_finished(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
Self::alloc(vm, Self::CF_LABEL, outer_partial, inner_partial, 12, 32)
}
pub(crate) fn alloc_server_finished(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12, 32)
}
pub(crate) fn wants_flush(&self) -> bool {
let is_computing = match self.state {
State::Computing => true,
State::Finished => false,
};
is_computing && self.start_seed_label.is_some()
}
pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), PrfError> {
if let State::Computing = self.state {
let a = self.a.first().expect("prf should be allocated");
let msg = *a.msg.first().expect("message for prf should be present");
let msg_value = self
.start_seed_label
.clone()
.expect("Start seed should have been set");
vm.assign(msg, msg_value).map_err(PrfError::vm)?;
vm.commit(msg).map_err(PrfError::vm)?;
self.state = State::Finished;
}
Ok(())
}
pub(crate) fn set_start_seed(&mut self, seed: Vec<u8>) {
let mut start_seed_label = self.label.to_vec();
start_seed_label.extend_from_slice(&seed);
self.start_seed_label = Some(start_seed_label);
}
pub(crate) fn output(&self) -> Vec<Array<U8, 32>> {
self.p.iter().map(|p| p.output).collect()
}
fn alloc(
vm: &mut dyn Vm<Binary>,
label: &'static [u8],
outer_partial: Sha256,
inner_partial: Sha256,
output_len: usize,
seed_len: usize,
) -> Result<Self, PrfError> {
let mut prf = Self {
label,
state: State::Computing,
start_seed_label: None,
a: vec![],
p: vec![],
};
assert!(output_len > 0, "cannot compute 0 bytes for prf");
let iterations = output_len.div_ceil(32);
let msg_len_a = label.len() + seed_len;
let seed_label_ref: Vector<U8> = vm.alloc_vec(msg_len_a).map_err(PrfError::vm)?;
vm.mark_public(seed_label_ref).map_err(PrfError::vm)?;
let mut msg_a = seed_label_ref;
for _ in 0..iterations {
let a = PHash::alloc(vm, outer_partial.clone(), inner_partial.clone(), &[msg_a])?;
msg_a = Vector::<U8>::from(a.output);
prf.a.push(a);
let p = PHash::alloc(
vm,
outer_partial.clone(),
inner_partial.clone(),
&[msg_a, seed_label_ref],
)?;
prf.p.push(p);
}
Ok(prf)
}
}
#[derive(Debug, Clone, Copy)]
enum State {
Computing,
Finished,
}
#[derive(Debug, Clone)]
struct PHash {
msg: Vec<Vector<U8>>,
output: Array<U8, 32>,
}
impl PHash {
fn alloc(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
msg: &[Vector<U8>],
) -> Result<Self, PrfError> {
let mut inner_local = inner_partial;
msg.iter().for_each(|m| inner_local.update(m));
inner_local.compress(vm)?;
let inner_local = inner_local.finalize(vm)?;
let output = hmac_sha256(vm, outer_partial, inner_local)?;
let p_hash = Self {
msg: msg.to_vec(),
output,
};
Ok(p_hash)
}
}

View File

@@ -0,0 +1,247 @@
//! Computes some hashes of the PRF locally.
use std::collections::VecDeque;
use crate::{hmac::hmac_sha256, sha256, state_to_bytes, PrfError};
use mpz_core::bitvec::BitVec;
use mpz_hash::sha256::Sha256;
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array, DecodeFutureTyped, MemoryExt, ViewExt,
},
Vm,
};
#[derive(Debug)]
pub(crate) struct PrfFunction {
// The label, e.g. "master secret".
label: &'static [u8],
// The start seed and the label, e.g. client_random + server_random + "master_secret".
start_seed_label: Option<Vec<u8>>,
iterations: usize,
state: PrfState,
a: VecDeque<AHash>,
p: VecDeque<PHash>,
}
#[derive(Debug)]
enum PrfState {
InnerPartial {
inner_partial: DecodeFutureTyped<BitVec, [u32; 8]>,
},
ComputeA {
iter: usize,
inner_partial: [u32; 8],
msg: Vec<u8>,
},
ComputeP {
iter: usize,
inner_partial: [u32; 8],
a_output: DecodeFutureTyped<BitVec, [u8; 32]>,
},
FinishLastP,
Done,
}
impl PrfFunction {
const MS_LABEL: &[u8] = b"master secret";
const KEY_LABEL: &[u8] = b"key expansion";
const CF_LABEL: &[u8] = b"client finished";
const SF_LABEL: &[u8] = b"server finished";
pub(crate) fn alloc_master_secret(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
Self::alloc(vm, Self::MS_LABEL, outer_partial, inner_partial, 48)
}
pub(crate) fn alloc_key_expansion(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
Self::alloc(vm, Self::KEY_LABEL, outer_partial, inner_partial, 40)
}
pub(crate) fn alloc_client_finished(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
Self::alloc(vm, Self::CF_LABEL, outer_partial, inner_partial, 12)
}
pub(crate) fn alloc_server_finished(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12)
}
pub(crate) fn wants_flush(&self) -> bool {
!matches!(self.state, PrfState::Done) && self.start_seed_label.is_some()
}
pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), PrfError> {
match &mut self.state {
PrfState::InnerPartial { inner_partial } => {
let Some(inner_partial) = inner_partial.try_recv().map_err(PrfError::vm)? else {
return Ok(());
};
self.state = PrfState::ComputeA {
iter: 1,
inner_partial,
msg: self
.start_seed_label
.clone()
.expect("Start seed should have been set"),
};
self.flush(vm)?;
}
PrfState::ComputeA {
iter,
inner_partial,
msg,
} => {
let a = self.a.pop_front().expect("Prf AHash should be present");
assign_inner_local(vm, a.inner_local, *inner_partial, msg)?;
self.state = PrfState::ComputeP {
iter: *iter,
inner_partial: *inner_partial,
a_output: a.output,
};
}
PrfState::ComputeP {
iter,
inner_partial,
a_output,
} => {
let Some(output) = a_output.try_recv().map_err(PrfError::vm)? else {
return Ok(());
};
let p = self.p.pop_front().expect("Prf PHash should be present");
let mut msg = output.to_vec();
msg.extend_from_slice(
self.start_seed_label
.as_ref()
.expect("Start seed should have been set"),
);
assign_inner_local(vm, p.inner_local, *inner_partial, &msg)?;
if *iter == self.iterations {
self.state = PrfState::FinishLastP;
} else {
self.state = PrfState::ComputeA {
iter: *iter + 1,
inner_partial: *inner_partial,
msg: output.to_vec(),
}
};
}
PrfState::FinishLastP => self.state = PrfState::Done,
_ => (),
}
Ok(())
}
pub(crate) fn set_start_seed(&mut self, seed: Vec<u8>) {
let mut start_seed_label = self.label.to_vec();
start_seed_label.extend_from_slice(&seed);
self.start_seed_label = Some(start_seed_label);
}
pub(crate) fn output(&self) -> Vec<Array<U8, 32>> {
self.p.iter().map(|p| p.output).collect()
}
fn alloc(
vm: &mut dyn Vm<Binary>,
label: &'static [u8],
outer_partial: Sha256,
inner_partial: Sha256,
len: usize,
) -> Result<Self, PrfError> {
assert!(len > 0, "cannot compute 0 bytes for prf");
let iterations = len.div_ceil(32);
let (inner_partial, _) = inner_partial
.state()
.expect("state should be set for inner_partial");
let inner_partial = vm.decode(inner_partial).map_err(PrfError::vm)?;
let mut prf = Self {
label,
start_seed_label: None,
iterations,
state: PrfState::InnerPartial { inner_partial },
a: VecDeque::new(),
p: VecDeque::new(),
};
for _ in 0..iterations {
// setup A[i]
let inner_local: Array<U8, 32> = vm.alloc().map_err(PrfError::vm)?;
let output = hmac_sha256(vm, outer_partial.clone(), inner_local)?;
let output = vm.decode(output).map_err(PrfError::vm)?;
let a_hash = AHash {
inner_local,
output,
};
prf.a.push_front(a_hash);
// setup P[i]
let inner_local: Array<U8, 32> = vm.alloc().map_err(PrfError::vm)?;
let output = hmac_sha256(vm, outer_partial.clone(), inner_local)?;
let p_hash = PHash {
inner_local,
output,
};
prf.p.push_front(p_hash);
}
Ok(prf)
}
}
fn assign_inner_local(
vm: &mut dyn Vm<Binary>,
inner_local: Array<U8, 32>,
inner_partial: [u32; 8],
msg: &[u8],
) -> Result<(), PrfError> {
let inner_local_value = sha256(inner_partial, 64, msg);
vm.mark_public(inner_local).map_err(PrfError::vm)?;
vm.assign(inner_local, state_to_bytes(inner_local_value))
.map_err(PrfError::vm)?;
vm.commit(inner_local).map_err(PrfError::vm)?;
Ok(())
}
/// Like PHash but stores the output as the decoding future because in the
/// reduced Prf we need to decode this output.
#[derive(Debug)]
struct AHash {
inner_local: Array<U8, 32>,
output: DecodeFutureTyped<BitVec, [u8; 32]>,
}
#[derive(Debug, Clone, Copy)]
struct PHash {
inner_local: Array<U8, 32>,
output: Array<U8, 32>,
}

View File

@@ -0,0 +1,103 @@
use crate::{
prf::{function::Prf, merge_outputs},
PrfError, PrfOutput, SessionKeys,
};
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array, FromRaw, ToRaw,
},
Vm,
};
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub(crate) enum State {
Initialized,
SessionKeys {
client_random: Option<[u8; 32]>,
master_secret: Prf,
key_expansion: Prf,
client_finished: Prf,
server_finished: Prf,
},
ClientFinished {
client_finished: Prf,
server_finished: Prf,
},
ServerFinished {
server_finished: Prf,
},
Complete,
Error,
}
impl State {
pub(crate) fn take(&mut self) -> State {
std::mem::replace(self, State::Error)
}
pub(crate) fn prf_output(&self, vm: &mut dyn Vm<Binary>) -> Result<PrfOutput, PrfError> {
let State::SessionKeys {
key_expansion,
client_finished,
server_finished,
..
} = self
else {
return Err(PrfError::state(
"Prf output can only be computed while in \"SessionKeys\" state",
));
};
let keys = get_session_keys(key_expansion.output(), vm)?;
let cf_vd = get_client_finished_vd(client_finished.output(), vm)?;
let sf_vd = get_server_finished_vd(server_finished.output(), vm)?;
let output = PrfOutput { keys, cf_vd, sf_vd };
Ok(output)
}
}
fn get_session_keys(
output: Vec<Array<U8, 32>>,
vm: &mut dyn Vm<Binary>,
) -> Result<SessionKeys, PrfError> {
let mut keys = merge_outputs(vm, output, 40)?;
debug_assert!(keys.len() == 40, "session keys len should be 40");
let server_iv = Array::<U8, 4>::try_from(keys.split_off(36)).unwrap();
let client_iv = Array::<U8, 4>::try_from(keys.split_off(32)).unwrap();
let server_write_key = Array::<U8, 16>::try_from(keys.split_off(16)).unwrap();
let client_write_key = Array::<U8, 16>::try_from(keys).unwrap();
let session_keys = SessionKeys {
client_write_key,
server_write_key,
client_iv,
server_iv,
};
Ok(session_keys)
}
fn get_client_finished_vd(
output: Vec<Array<U8, 32>>,
vm: &mut dyn Vm<Binary>,
) -> Result<Array<U8, 12>, PrfError> {
let cf_vd = merge_outputs(vm, output, 12)?;
let cf_vd = <Array<U8, 12> as FromRaw<Binary>>::from_raw(cf_vd.to_raw());
Ok(cf_vd)
}
fn get_server_finished_vd(
output: Vec<Array<U8, 32>>,
vm: &mut dyn Vm<Binary>,
) -> Result<Array<U8, 12>, PrfError> {
let sf_vd = merge_outputs(vm, output, 12)?;
let sf_vd = <Array<U8, 12> as FromRaw<Binary>>::from_raw(sf_vd.to_raw());
Ok(sf_vd)
}

View File

@@ -0,0 +1,261 @@
use crate::{sha256, state_to_bytes};
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_ot::ideal::cot::{ideal_cot, IdealCOTReceiver, IdealCOTSender};
use mpz_vm_core::memory::correlated::Delta;
use rand::{rngs::StdRng, Rng, SeedableRng};
pub(crate) const SHA256_IV: [u32; 8] = [
0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
];
pub(crate) fn mock_vm() -> (Garbler<IdealCOTSender>, Evaluator<IdealCOTReceiver>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
let gen = Garbler::new(cot_send, [0u8; 16], delta);
let ev = Evaluator::new(cot_recv);
(gen, ev)
}
pub(crate) fn prf_ms(pms: [u8; 32], client_random: [u8; 32], server_random: [u8; 32]) -> [u8; 48] {
let mut label_start_seed = b"master secret".to_vec();
label_start_seed.extend_from_slice(&client_random);
label_start_seed.extend_from_slice(&server_random);
let ms = phash(pms.to_vec(), &label_start_seed, 2)[..48].to_vec();
ms.try_into().unwrap()
}
pub(crate) fn prf_keys(
ms: [u8; 48],
client_random: [u8; 32],
server_random: [u8; 32],
) -> [Vec<u8>; 4] {
let mut label_start_seed = b"key expansion".to_vec();
label_start_seed.extend_from_slice(&server_random);
label_start_seed.extend_from_slice(&client_random);
let mut session_keys = phash(ms.to_vec(), &label_start_seed, 2)[..40].to_vec();
let server_iv = session_keys.split_off(36);
let client_iv = session_keys.split_off(32);
let server_write_key = session_keys.split_off(16);
let client_write_key = session_keys;
[client_write_key, server_write_key, client_iv, server_iv]
}
pub(crate) fn prf_cf_vd(ms: [u8; 48], hanshake_hash: [u8; 32]) -> Vec<u8> {
let mut label_start_seed = b"client finished".to_vec();
label_start_seed.extend_from_slice(&hanshake_hash);
phash(ms.to_vec(), &label_start_seed, 1)[..12].to_vec()
}
pub(crate) fn prf_sf_vd(ms: [u8; 48], hanshake_hash: [u8; 32]) -> Vec<u8> {
let mut label_start_seed = b"server finished".to_vec();
label_start_seed.extend_from_slice(&hanshake_hash);
phash(ms.to_vec(), &label_start_seed, 1)[..12].to_vec()
}
pub(crate) fn phash(key: Vec<u8>, seed: &[u8], iterations: usize) -> Vec<u8> {
// A() is defined as:
//
// A(0) = seed
// A(i) = HMAC_hash(secret, A(i-1))
let mut a_cache: Vec<_> = Vec::with_capacity(iterations + 1);
a_cache.push(seed.to_vec());
for i in 0..iterations {
let a_i = hmac_sha256(key.clone(), &a_cache[i]);
a_cache.push(a_i.to_vec());
}
// HMAC_hash(secret, A(i) + seed)
let mut output: Vec<_> = Vec::with_capacity(iterations * 32);
for i in 0..iterations {
let mut a_i_seed = a_cache[i + 1].clone();
a_i_seed.extend_from_slice(seed);
let hash = hmac_sha256(key.clone(), &a_i_seed);
output.extend_from_slice(&hash);
}
output
}
pub(crate) fn hmac_sha256(key: Vec<u8>, msg: &[u8]) -> [u8; 32] {
let outer_partial = compute_outer_partial(key.clone());
let inner_local = compute_inner_local(key, msg);
let hmac = sha256(outer_partial, 64, &state_to_bytes(inner_local));
state_to_bytes(hmac)
}
pub(crate) fn compute_outer_partial(mut key: Vec<u8>) -> [u32; 8] {
assert!(key.len() <= 64);
key.resize(64, 0_u8);
let key_padded: [u8; 64] = key
.into_iter()
.map(|b| b ^ 0x5c)
.collect::<Vec<u8>>()
.try_into()
.unwrap();
compress_256(SHA256_IV, &key_padded)
}
pub(crate) fn compute_inner_local(mut key: Vec<u8>, msg: &[u8]) -> [u32; 8] {
assert!(key.len() <= 64);
key.resize(64, 0_u8);
let key_padded: [u8; 64] = key
.into_iter()
.map(|b| b ^ 0x36)
.collect::<Vec<u8>>()
.try_into()
.unwrap();
let state = compress_256(SHA256_IV, &key_padded);
sha256(state, 64, msg)
}
pub(crate) fn compress_256(mut state: [u32; 8], msg: &[u8]) -> [u32; 8] {
use sha2::{
compress256,
digest::{
block_buffer::{BlockBuffer, Eager},
generic_array::typenum::U64,
},
};
let mut buffer = BlockBuffer::<U64, Eager>::default();
buffer.digest_blocks(msg, |b| compress256(&mut state, b));
state
}
// Borrowed from Rustls for testing
// https://github.com/rustls/rustls/blob/main/rustls/src/tls12/prf.rs
mod ring_prf {
use ring::{hmac, hmac::HMAC_SHA256};
fn concat_sign(key: &hmac::Key, a: &[u8], b: &[u8]) -> hmac::Tag {
let mut ctx = hmac::Context::with_key(key);
ctx.update(a);
ctx.update(b);
ctx.sign()
}
fn p(out: &mut [u8], secret: &[u8], seed: &[u8]) {
let hmac_key = hmac::Key::new(HMAC_SHA256, secret);
// A(1)
let mut current_a = hmac::sign(&hmac_key, seed);
let chunk_size = HMAC_SHA256.digest_algorithm().output_len();
for chunk in out.chunks_mut(chunk_size) {
// P_hash[i] = HMAC_hash(secret, A(i) + seed)
let p_term = concat_sign(&hmac_key, current_a.as_ref(), seed);
chunk.copy_from_slice(&p_term.as_ref()[..chunk.len()]);
// A(i+1) = HMAC_hash(secret, A(i))
current_a = hmac::sign(&hmac_key, current_a.as_ref());
}
}
fn concat(a: &[u8], b: &[u8]) -> Vec<u8> {
let mut ret = Vec::new();
ret.extend_from_slice(a);
ret.extend_from_slice(b);
ret
}
pub(crate) fn prf(out: &mut [u8], secret: &[u8], label: &[u8], seed: &[u8]) {
let joined_seed = concat(label, seed);
p(out, secret, &joined_seed);
}
}
#[test]
fn test_prf_reference_ms() {
use ring_prf::prf as prf_ref;
let mut rng = StdRng::from_seed([1; 32]);
let pms: [u8; 32] = rng.random();
let label: &[u8] = b"master secret";
let client_random: [u8; 32] = rng.random();
let server_random: [u8; 32] = rng.random();
let mut seed = Vec::from(client_random);
seed.extend_from_slice(&server_random);
let ms = prf_ms(pms, client_random, server_random);
let mut expected_ms: [u8; 48] = [0; 48];
prf_ref(&mut expected_ms, &pms, label, &seed);
assert_eq!(ms, expected_ms);
}
#[test]
fn test_prf_reference_ke() {
use ring_prf::prf as prf_ref;
let mut rng = StdRng::from_seed([2; 32]);
let ms: [u8; 48] = rng.random();
let label: &[u8] = b"key expansion";
let client_random: [u8; 32] = rng.random();
let server_random: [u8; 32] = rng.random();
let mut seed = Vec::from(server_random);
seed.extend_from_slice(&client_random);
let keys = prf_keys(ms, client_random, server_random);
let keys: Vec<u8> = keys.into_iter().flatten().collect();
let mut expected_keys: [u8; 40] = [0; 40];
prf_ref(&mut expected_keys, &ms, label, &seed);
assert_eq!(keys, expected_keys);
}
#[test]
fn test_prf_reference_cf() {
use ring_prf::prf as prf_ref;
let mut rng = StdRng::from_seed([3; 32]);
let ms: [u8; 48] = rng.random();
let label: &[u8] = b"client finished";
let handshake_hash: [u8; 32] = rng.random();
let cf_vd = prf_cf_vd(ms, handshake_hash);
let mut expected_cf_vd: [u8; 12] = [0; 12];
prf_ref(&mut expected_cf_vd, &ms, label, &handshake_hash);
assert_eq!(cf_vd, expected_cf_vd);
}
#[test]
fn test_prf_reference_sf() {
use ring_prf::prf as prf_ref;
let mut rng = StdRng::from_seed([4; 32]);
let ms: [u8; 48] = rng.random();
let label: &[u8] = b"server finished";
let handshake_hash: [u8; 32] = rng.random();
let sf_vd = prf_sf_vd(ms, handshake_hash);
let mut expected_sf_vd: [u8; 12] = [0; 12];
prf_ref(&mut expected_sf_vd, &ms, label, &handshake_hash);
assert_eq!(sf_vd, expected_sf_vd);
}

View File

@@ -20,7 +20,7 @@ default = []
[dependencies]
tlsn-cipher = { workspace = true }
tlsn-common = { workspace = true }
#tlsn-hmac-sha256 = { workspace = true }
tlsn-hmac-sha256 = { workspace = true }
tlsn-key-exchange = { workspace = true }
tlsn-tls-backend = { workspace = true }
tlsn-tls-core = { workspace = true, features = ["serde"] }

View File

@@ -1,4 +1,5 @@
use derive_builder::Builder;
use hmac_sha256::Mode as PrfMode;
/// Number of TLS protocol bytes that will be sent.
const PROTOCOL_DATA_SENT: usize = 32;
@@ -55,6 +56,9 @@ pub struct Config {
/// Maximum number of received bytes.
#[allow(unused)]
pub(crate) max_recv: usize,
/// Configuration options for the PRF.
#[builder(setter(custom))]
pub(crate) prf: PrfMode,
}
impl Config {
@@ -65,6 +69,12 @@ impl Config {
}
impl ConfigBuilder {
/// Optimizes the protocol for low bandwidth networks.
pub fn low_bandwidth(&mut self) -> &mut Self {
self.prf = Some(PrfMode::Reduced);
self
}
/// Builds the configuration.
pub fn build(&self) -> Result<Config, ConfigBuilderError> {
let defer_decryption = self.defer_decryption.unwrap_or(true);
@@ -95,6 +105,8 @@ impl ConfigBuilder {
.max_recv_records
.unwrap_or_else(|| PROTOCOL_RECORD_COUNT_RECV + default_record_count(max_recv));
let prf = self.prf.unwrap_or_default();
Ok(Config {
defer_decryption,
max_sent_records,
@@ -102,6 +114,7 @@ impl ConfigBuilder {
max_recv_records,
max_recv_online,
max_recv,
prf,
})
}
}

View File

@@ -3,7 +3,7 @@ use crate::{
record_layer::{aead::MpcAesGcm, RecordLayer},
Config, FollowerData, MpcTlsError, Role, SessionKeys, Vm,
};
use hmac_sha256::{MpcPrf, PrfConfig, PrfOutput};
use hmac_sha256::{MpcPrf, PrfOutput};
use ke::KeyExchange;
use key_exchange::{self as ke, MpcKeyExchange};
use mpz_common::{Context, Flush};
@@ -63,12 +63,7 @@ impl MpcTlsFollower {
)),
)) as Box<dyn KeyExchange + Send + Sync>;
let prf = MpcPrf::new(
PrfConfig::builder()
.role(hmac_sha256::Role::Follower)
.build()
.expect("PRF config is valid"),
);
let prf = MpcPrf::new(config.prf);
let encrypter = MpcAesGcm::new(
ShareConversionReceiver::new(OLEReceiver::new(AnyReceiver::new(
@@ -123,8 +118,6 @@ impl MpcTlsFollower {
keys.server_iv,
)?;
prf.set_client_random(vm, None)?;
let cf_vd = vm.decode(cf_vd).map_err(MpcTlsError::alloc)?;
let sf_vd = vm.decode(sf_vd).map_err(MpcTlsError::alloc)?;
@@ -230,6 +223,7 @@ impl MpcTlsFollower {
return Err(MpcTlsError::state("must be in ready state to run"));
};
let mut client_random = None;
let mut server_random = None;
let mut server_key = None;
let mut cf_vd = None;
@@ -237,17 +231,20 @@ impl MpcTlsFollower {
loop {
let msg: Message = self.ctx.io_mut().expect_next().await?;
match msg {
Message::SetClientRandom(random) => {
if client_random.is_some() {
return Err(MpcTlsError::hs("client random already set"));
}
prf.set_client_random(random.random)?;
client_random = Some(random);
}
Message::SetServerRandom(random) => {
if server_random.is_some() {
return Err(MpcTlsError::hs("server random already set"));
}
let mut vm = vm
.try_lock()
.map_err(|_| MpcTlsError::other("VM lock is held"))?;
prf.set_server_random(&mut (*vm), random.random)?;
prf.set_server_random(random.random)?;
server_random = Some(random);
}
Message::SetServerKey(key) => {
@@ -274,9 +271,12 @@ impl MpcTlsFollower {
ke.compute_shares(&mut self.ctx).await?;
ke.assign(&mut (*vm))?;
vm.execute_all(&mut self.ctx)
.await
.map_err(MpcTlsError::hs)?;
while prf.wants_flush() {
prf.flush(&mut *vm)?;
vm.execute_all(&mut self.ctx)
.await
.map_err(MpcTlsError::hs)?;
}
ke.finalize().await?;
record_layer.setup(&mut self.ctx).await?;
@@ -290,11 +290,14 @@ impl MpcTlsFollower {
.try_lock()
.map_err(|_| MpcTlsError::other("VM lock is held"))?;
prf.set_cf_hash(&mut (*vm), vd.handshake_hash)?;
prf.set_cf_hash(vd.handshake_hash)?;
vm.execute_all(&mut self.ctx)
.await
.map_err(MpcTlsError::hs)?;
while prf.wants_flush() {
prf.flush(&mut *vm)?;
vm.execute_all(&mut self.ctx)
.await
.map_err(MpcTlsError::hs)?;
}
cf_vd = Some(
cf_vd_fut
@@ -312,11 +315,14 @@ impl MpcTlsFollower {
.try_lock()
.map_err(|_| MpcTlsError::other("VM lock is held"))?;
prf.set_sf_hash(&mut (*vm), vd.handshake_hash)?;
prf.set_sf_hash(vd.handshake_hash)?;
vm.execute_all(&mut self.ctx)
.await
.map_err(MpcTlsError::hs)?;
while prf.wants_flush() {
prf.flush(&mut *vm)?;
vm.execute_all(&mut self.ctx)
.await
.map_err(MpcTlsError::hs)?;
}
sf_vd = Some(
sf_vd_fut

View File

@@ -3,15 +3,15 @@ mod actor;
use crate::{
error::MpcTlsError,
msg::{
ClientFinishedVd, Decrypt, Encrypt, Message, ServerFinishedVd, SetServerKey,
SetServerRandom,
ClientFinishedVd, Decrypt, Encrypt, Message, ServerFinishedVd, SetClientRandom,
SetServerKey, SetServerRandom,
},
record_layer::{aead::MpcAesGcm, DecryptMode, EncryptMode, RecordLayer},
utils::opaque_into_parts,
Config, LeaderOutput, Role, SessionKeys, Vm,
};
use async_trait::async_trait;
use hmac_sha256::{MpcPrf, PrfConfig, PrfOutput};
use hmac_sha256::{MpcPrf, PrfOutput};
use ke::KeyExchange;
use key_exchange::{self as ke, MpcKeyExchange};
use ludi::Context as LudiContext;
@@ -87,12 +87,7 @@ impl MpcTlsLeader {
))),
)) as Box<dyn KeyExchange + Send + Sync>;
let prf = MpcPrf::new(
PrfConfig::builder()
.role(hmac_sha256::Role::Leader)
.build()
.expect("prf config is valid"),
);
let prf = MpcPrf::new(config.prf);
let encrypter = MpcAesGcm::new(
ShareConversionSender::new(OLESender::new(
@@ -157,8 +152,6 @@ impl MpcTlsLeader {
keys.server_iv,
)?;
prf.set_client_random(&mut (*vm_lock), Some(client_random.0))?;
let cf_vd = vm_lock.decode(cf_vd).map_err(MpcTlsError::alloc)?;
let sf_vd = vm_lock.decode(sf_vd).map_err(MpcTlsError::alloc)?;
@@ -194,7 +187,7 @@ impl MpcTlsLeader {
vm,
keys,
mut ke,
prf,
mut prf,
mut record_layer,
cf_vd,
sf_vd,
@@ -238,6 +231,15 @@ impl MpcTlsLeader {
.await
.map_err(MpcTlsError::preprocess)??;
ctx.io_mut()
.send(Message::SetClientRandom(SetClientRandom {
random: client_random.0,
}))
.await
.map_err(MpcTlsError::from)?;
prf.set_client_random(client_random.0)?;
self.state = State::Handshake {
ctx,
vm,
@@ -408,7 +410,6 @@ impl Backend for MpcTlsLeader {
async fn set_server_random(&mut self, random: Random) -> Result<(), BackendError> {
let State::Handshake {
ctx,
vm,
prf,
server_random,
..
@@ -426,13 +427,7 @@ impl Backend for MpcTlsLeader {
.await
.map_err(MpcTlsError::from)?;
let mut vm = vm
.try_lock()
.map_err(|_| MpcTlsError::other("VM lock is held"))?;
prf.set_server_random(&mut (*vm), random.0)
.map_err(MpcTlsError::hs)?;
prf.set_server_random(random.0).map_err(MpcTlsError::hs)?;
*server_random = Some(random);
Ok(())
@@ -543,9 +538,12 @@ impl Backend for MpcTlsLeader {
let mut vm = vm
.try_lock()
.map_err(|_| MpcTlsError::other("VM lock is held"))?;
prf.set_sf_hash(&mut (*vm), hash).map_err(MpcTlsError::hs)?;
prf.set_sf_hash(hash).map_err(MpcTlsError::hs)?;
vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?;
while prf.wants_flush() {
prf.flush(&mut *vm).map_err(MpcTlsError::hs)?;
vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?;
}
let sf_vd = sf_vd
.try_recv()
@@ -586,9 +584,12 @@ impl Backend for MpcTlsLeader {
let mut vm = vm
.try_lock()
.map_err(|_| MpcTlsError::hs("VM lock is held"))?;
prf.set_cf_hash(&mut (*vm), hash).map_err(MpcTlsError::hs)?;
prf.set_cf_hash(hash).map_err(MpcTlsError::hs)?;
vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?;
while prf.wants_flush() {
prf.flush(&mut *vm).map_err(MpcTlsError::hs)?;
vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?;
}
let cf_vd = cf_vd
.try_recv()
@@ -605,7 +606,7 @@ impl Backend for MpcTlsLeader {
vm,
keys,
mut ke,
prf,
mut prf,
mut record_layer,
cf_vd,
sf_vd,
@@ -650,10 +651,15 @@ impl Backend for MpcTlsLeader {
.map_err(|_| MpcTlsError::other("VM lock is held"))?;
ke.assign(&mut (*vm_lock)).map_err(MpcTlsError::hs)?;
vm_lock
.execute_all(&mut ctx)
.await
.map_err(MpcTlsError::hs)?;
while prf.wants_flush() {
prf.flush(&mut *vm_lock).map_err(MpcTlsError::hs)?;
vm_lock
.execute_all(&mut ctx)
.await
.map_err(MpcTlsError::hs)?;
}
ke.finalize().await.map_err(MpcTlsError::hs)?;
record_layer.setup(&mut ctx).await?;
}

View File

@@ -9,6 +9,7 @@ use crate::record_layer::{DecryptMode, EncryptMode};
/// MPC-TLS protocol message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) enum Message {
SetClientRandom(SetClientRandom),
SetServerRandom(SetServerRandom),
SetServerKey(SetServerKey),
ClientFinishedVd(ClientFinishedVd),
@@ -20,6 +21,11 @@ pub(crate) enum Message {
CloseConnection,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct SetClientRandom {
pub(crate) random: [u8; 32],
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct SetServerRandom {
pub(crate) random: [u8; 32],

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use mpc_tls::Config;
use tlsn_common::config::ProtocolConfig;
use tlsn_common::config::{NetworkSetting, ProtocolConfig};
use tlsn_core::{connection::ServerName, CryptoProvider};
/// Configuration for the prover
@@ -55,6 +55,10 @@ impl ProverConfig {
builder.max_recv_records(max_recv_records);
}
if let NetworkSetting::Latency = self.protocol_config.network() {
builder.low_bandwidth();
}
builder.build().unwrap()
}
}

View File

@@ -17,7 +17,6 @@ pub use future::ProverFuture;
use mpz_common::Context;
use mpz_core::Block;
use mpz_garble_core::Delta;
use rand06_compat::Rand0_6CompatExt;
use state::{Notarize, Prove};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};

View File

@@ -4,7 +4,7 @@ use std::{
};
use mpc_tls::Config;
use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator};
use tlsn_common::config::{NetworkSetting, ProtocolConfig, ProtocolConfigValidator};
use tlsn_core::CryptoProvider;
/// Configuration for the [`Verifier`](crate::tls::Verifier).
@@ -58,6 +58,10 @@ impl VerifierConfig {
builder.max_recv_records(max_recv_records);
}
if let NetworkSetting::Latency = protocol_config.network() {
builder.low_bandwidth();
}
builder.build().unwrap()
}
}

View File

@@ -20,7 +20,6 @@ use mpc_tls::{FollowerData, MpcTlsFollower};
use mpz_common::Context;
use mpz_core::Block;
use mpz_garble_core::Delta;
use rand06_compat::Rand0_6CompatExt;
use serio::stream::IoStreamExt;
use state::{Notarize, Verify};
use tls_core::msgs::enums::ContentType;

View File

@@ -1,5 +1,5 @@
use serde::Deserialize;
use tlsn_common::config::ProtocolConfig;
use tlsn_common::config::{NetworkSetting, ProtocolConfig};
use tsify_next::Tsify;
#[derive(Debug, Tsify, Deserialize)]
@@ -12,6 +12,7 @@ pub struct ProverConfig {
pub defer_decryption_from_start: Option<bool>,
pub max_sent_records: Option<usize>,
pub max_recv_records: Option<usize>,
pub network: NetworkSetting,
}
impl From<ProverConfig> for tlsn_prover::ProverConfig {
@@ -37,6 +38,7 @@ impl From<ProverConfig> for tlsn_prover::ProverConfig {
builder.defer_decryption_from_start(value);
}
builder.network(value.network);
let protocol_config = builder.build().unwrap();
let mut builder = tlsn_prover::ProverConfig::builder();