feat: prove server mac key (#868)

* feat(mpc-tls): prove server mac key

* remove stray dep

* move mac key into `SessionKeys`

* fix key translation

* remove dangling dep

* move ghash mod to tlsn-common

* fix clippy lints

* treat all recv recs as unauthenticated

* detach zkvm first, then prove

* decrypt with aes_gcm, decode mac key only in zkvm

* encapsulate into `fn verify_tags`; inline mod `zk_aes_ecb`

* handle error

* fix dangling and clippy

* bump Cargo.lock
This commit is contained in:
dan
2025-06-05 16:19:41 +00:00
committed by GitHub
parent 55a26aad77
commit 345d5d45ad
34 changed files with 524 additions and 247 deletions

1
Cargo.lock generated
View File

@@ -7232,6 +7232,7 @@ dependencies = [
"async-trait",
"derive_builder 0.12.0",
"futures 0.3.31",
"ghash 0.5.1",
"mpz-common",
"mpz-core",
"mpz-hash",

View File

@@ -120,6 +120,7 @@ futures = { version = "0.3" }
futures-rustls = { version = "0.26" }
futures-util = { version = "0.3" }
generic-array = { version = "0.14" }
ghash = { version = "0.5" }
hex = { version = "0.4" }
hmac = { version = "0.12" }
http = { version = "1.1" }

View File

@@ -24,6 +24,7 @@ mpz-zk = { workspace = true }
async-trait = { workspace = true }
derive_builder = { workspace = true }
futures = { workspace = true }
ghash = { workspace = true }
once_cell = { workspace = true }
opaque-debug = { workspace = true }
rand = { workspace = true }

View File

@@ -8,7 +8,7 @@ use mpz_vm_core::{prelude::*, Vm};
use crate::{
transcript::Record,
zk_aes::{ZkAesCtr, ZkAesCtrError},
zk_aes_ctr::{ZkAesCtr, ZkAesCtrError},
Role,
};

View File

@@ -34,9 +34,10 @@ pub struct ProtocolConfig {
max_recv_data_online: usize,
/// Maximum number of bytes that can be received.
max_recv_data: usize,
/// Maximum number of application data records that can be received.
/// Maximum number of received application data records that can be
/// decrypted online, i.e. while the MPC-TLS connection is active.
#[builder(setter(strip_option), default)]
max_recv_records: Option<usize>,
max_recv_records_online: Option<usize>,
/// Whether the `deferred decryption` feature is toggled on from the start
/// of the MPC-TLS connection.
#[builder(default = "true")]
@@ -87,10 +88,10 @@ impl ProtocolConfig {
self.max_recv_data
}
/// Returns the maximum number of application data records that can
/// be received.
pub fn max_recv_records(&self) -> Option<usize> {
self.max_recv_records
/// Returns the maximum number of received application data records that
/// can be decrypted online.
pub fn max_recv_records_online(&self) -> Option<usize> {
self.max_recv_records_online
}
/// Returns whether the `deferred decryption` feature is toggled on from the
@@ -116,9 +117,9 @@ pub struct ProtocolConfigValidator {
max_sent_records: usize,
/// Maximum number of bytes that can be received.
max_recv_data: usize,
/// Maximum number of application data records that can be received.
/// Maximum number of application data records that can be received online.
#[builder(default = "DEFAULT_RECORDS_LIMIT")]
max_recv_records: usize,
max_recv_records_online: usize,
/// Version that is being run by checker.
#[builder(setter(skip), default = "VERSION.clone()")]
version: Version,
@@ -147,16 +148,16 @@ impl ProtocolConfigValidator {
}
/// Returns the maximum number of application data records that can
/// be received.
pub fn max_recv_records(&self) -> usize {
self.max_recv_records
/// be received online.
pub fn max_recv_records_online(&self) -> usize {
self.max_recv_records_online
}
/// Performs compatibility check of the protocol configuration between
/// prover and verifier.
pub fn validate(&self, config: &ProtocolConfig) -> Result<(), ProtocolConfigError> {
self.check_max_transcript_size(config.max_sent_data, config.max_recv_data)?;
self.check_max_records(config.max_sent_records, config.max_recv_records)?;
self.check_max_records(config.max_sent_records, config.max_recv_records_online)?;
self.check_version(&config.version)?;
Ok(())
}
@@ -187,7 +188,7 @@ impl ProtocolConfigValidator {
fn check_max_records(
&self,
max_sent_records: Option<usize>,
max_recv_records: Option<usize>,
max_recv_records_online: Option<usize>,
) -> Result<(), ProtocolConfigError> {
if let Some(max_sent_records) = max_sent_records {
if max_sent_records > self.max_sent_records {
@@ -198,11 +199,11 @@ impl ProtocolConfigValidator {
}
}
if let Some(max_recv_records) = max_recv_records {
if max_recv_records > self.max_recv_records {
if let Some(max_recv_records_online) = max_recv_records_online {
if max_recv_records_online > self.max_recv_records_online {
return Err(ProtocolConfigError::max_record_count(format!(
"max_recv_records {} is greater than the configured limit {}",
max_recv_records, self.max_recv_records,
"max_recv_records_online {} is greater than the configured limit {}",
max_recv_records_online, self.max_recv_records_online,
)));
}
}

View File

@@ -0,0 +1,39 @@
//! GHASH methods.
// This module belongs in tls/core. It was moved out here temporarily.
use ghash::{
universal_hash::{KeyInit, UniversalHash as UniversalHashReference},
GHash,
};
/// Computes a GHASH tag.
pub fn ghash(aad: &[u8], ciphertext: &[u8], key: &[u8; 16]) -> [u8; 16] {
let mut ghash = GHash::new(key.into());
ghash.update_padded(&build_ghash_data(aad.to_vec(), ciphertext.to_owned()));
let out = ghash.finalize();
out.into()
}
/// Builds padded data for GHASH.
pub fn build_ghash_data(mut aad: Vec<u8>, mut ciphertext: Vec<u8>) -> Vec<u8> {
let associated_data_bitlen = (aad.len() as u64) * 8;
let text_bitlen = (ciphertext.len() as u64) * 8;
let len_block = ((associated_data_bitlen as u128) << 64) + (text_bitlen as u128);
// Pad data to be a multiple of 16 bytes.
let aad_padded_block_count = (aad.len() / 16) + (aad.len() % 16 != 0) as usize;
aad.resize(aad_padded_block_count * 16, 0);
let ciphertext_padded_block_count =
(ciphertext.len() / 16) + (ciphertext.len() % 16 != 0) as usize;
ciphertext.resize(ciphertext_padded_block_count * 16, 0);
let mut data: Vec<u8> = Vec::with_capacity(aad.len() + ciphertext.len() + 16);
data.extend(aad);
data.extend(ciphertext);
data.extend_from_slice(&len_block.to_be_bytes());
data
}

View File

@@ -8,10 +8,12 @@ pub mod commit;
pub mod config;
pub mod context;
pub mod encoding;
pub mod ghash;
pub mod msg;
pub mod mux;
pub mod tag;
pub mod transcript;
pub mod zk_aes;
pub mod zk_aes_ctr;
/// The party's role in the TLSN protocol.
///

157
crates/common/src/tag.rs Normal file
View File

@@ -0,0 +1,157 @@
//! TLS record tag verification.
use crate::{ghash::ghash, transcript::Record};
use cipher::{aes::Aes128, Cipher};
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{
binary::{Binary, U8},
DecodeFutureTyped,
};
use mpz_vm_core::{prelude::*, Vm};
use tls_core::cipher::make_tls12_aad;
/// Proves the verification of tags of the given `records`,
/// returning a proof.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `key_iv` - Cipher key and IV.
/// * `mac_key` - MAC key.
/// * `records` - Records for which the verification is to be proven.
pub fn verify_tags(
vm: &mut dyn Vm<Binary>,
key_iv: (Array<U8, 16>, Array<U8, 4>),
mac_key: Array<U8, 16>,
records: Vec<Record>,
) -> Result<TagProof, TagProofError> {
let mut aes = Aes128::default();
aes.set_key(key_iv.0);
aes.set_iv(key_iv.1);
// Compute j0 blocks.
let j0s = records
.iter()
.map(|rec| {
let block = aes.alloc_ctr_block(vm).map_err(TagProofError::vm)?;
let explicit_nonce: [u8; 8] =
rec.explicit_nonce
.clone()
.try_into()
.map_err(|explicit_nonce: Vec<_>| ErrorRepr::ExplicitNonceLength {
expected: 8,
actual: explicit_nonce.len(),
})?;
vm.assign(block.explicit_nonce, explicit_nonce)
.map_err(TagProofError::vm)?;
vm.commit(block.explicit_nonce).map_err(TagProofError::vm)?;
// j0's counter is set to 1.
vm.assign(block.counter, 1u32.to_be_bytes())
.map_err(TagProofError::vm)?;
vm.commit(block.counter).map_err(TagProofError::vm)?;
let j0 = vm.decode(block.output).map_err(TagProofError::vm)?;
Ok(j0)
})
.collect::<Result<Vec<_>, TagProofError>>()?;
let mac_key = vm.decode(mac_key).map_err(TagProofError::vm)?;
Ok(TagProof {
j0s,
records,
mac_key,
})
}
/// Proof of tag verification.
#[derive(Debug)]
#[must_use]
pub struct TagProof {
/// The j0 block for each record.
j0s: Vec<DecodeFutureTyped<BitVec, [u8; 16]>>,
records: Vec<Record>,
/// The MAC key for tag computation.
mac_key: DecodeFutureTyped<BitVec, [u8; 16]>,
}
impl TagProof {
/// Verifies the proof.
pub fn verify(self) -> Result<(), TagProofError> {
let Self {
j0s,
mut mac_key,
records,
} = self;
let mac_key = mac_key
.try_recv()
.map_err(TagProofError::vm)?
.ok_or_else(|| ErrorRepr::NotDecoded)?;
for (mut j0, rec) in j0s.into_iter().zip(records) {
let j0 = j0
.try_recv()
.map_err(TagProofError::vm)?
.ok_or_else(|| ErrorRepr::NotDecoded)?;
let aad = make_tls12_aad(rec.seq, rec.typ, rec.version, rec.ciphertext.len());
let ghash_tag = ghash(aad.as_ref(), &rec.ciphertext, &mac_key);
let record_tag = match rec.tag.as_ref() {
Some(tag) => tag,
None => {
// This will never happen, since we only call this method
// for proofs where the records' tags are known.
return Err(ErrorRepr::UnknownTag.into());
}
};
if *record_tag
!= ghash_tag
.into_iter()
.zip(j0.into_iter())
.map(|(a, b)| a ^ b)
.collect::<Vec<_>>()
{
return Err(ErrorRepr::InvalidTag.into());
}
}
Ok(())
}
}
/// Error for [`J0Proof`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct TagProofError(#[from] ErrorRepr);
impl TagProofError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
}
#[derive(Debug, thiserror::Error)]
#[error("j0 proof error: {0}")]
enum ErrorRepr {
#[error("value was not decoded")]
NotDecoded,
#[error("VM error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("tag does not match expected")]
InvalidTag,
#[error("tag is not known")]
UnknownTag,
#[error("invalid explicit nonce length: expected {expected}, got {actual}")]
ExplicitNonceLength { expected: usize, actual: usize },
}

View File

@@ -6,21 +6,21 @@ use mpz_memory_core::{
};
use mpz_vm_core::{Vm, VmError};
use rangeset::Intersection;
use tls_core::msgs::enums::ContentType;
use tls_core::msgs::enums::{ContentType, ProtocolVersion};
use tlsn_core::transcript::{Direction, Idx, PartialTranscript, Transcript};
/// A transcript of sent and received TLS records.
/// A transcript of TLS records sent and received by the prover.
#[derive(Debug, Default, Clone)]
pub struct TlsTranscript {
/// Records sent by the prover.
/// Sent records.
pub sent: Vec<Record>,
/// Records received by the prover.
/// Received records.
pub recv: Vec<Record>,
}
impl TlsTranscript {
/// Returns the application data transcript.
pub fn to_transcript(&self) -> Result<Transcript, IncompleteTranscript> {
pub fn to_transcript(&self) -> Result<Transcript, TlsTranscriptError> {
let mut sent = Vec::new();
let mut recv = Vec::new();
@@ -32,7 +32,7 @@ impl TlsTranscript {
let plaintext = record
.plaintext
.as_ref()
.ok_or(IncompleteTranscript {})?
.ok_or(ErrorRepr::IncompleteTranscript {})?
.clone();
sent.extend_from_slice(&plaintext);
}
@@ -45,7 +45,7 @@ impl TlsTranscript {
let plaintext = record
.plaintext
.as_ref()
.ok_or(IncompleteTranscript {})?
.ok_or(ErrorRepr::IncompleteTranscript {})?
.clone();
recv.extend_from_slice(&plaintext);
}
@@ -54,7 +54,7 @@ impl TlsTranscript {
}
/// Returns the application data transcript references.
pub fn to_transcript_refs(&self) -> Result<TranscriptRefs, IncompleteTranscript> {
pub fn to_transcript_refs(&self) -> Result<TranscriptRefs, TlsTranscriptError> {
let mut sent = Vec::new();
let mut recv = Vec::new();
@@ -66,7 +66,7 @@ impl TlsTranscript {
let plaintext_ref = record
.plaintext_ref
.as_ref()
.ok_or(IncompleteTranscript {})?;
.ok_or(ErrorRepr::IncompleteTranscript {})?;
sent.push(*plaintext_ref);
}
@@ -78,7 +78,7 @@ impl TlsTranscript {
let plaintext_ref = record
.plaintext_ref
.as_ref()
.ok_or(IncompleteTranscript {})?;
.ok_or(ErrorRepr::IncompleteTranscript {})?;
recv.push(*plaintext_ref);
}
@@ -101,6 +101,10 @@ pub struct Record {
pub explicit_nonce: Vec<u8>,
/// Ciphertext.
pub ciphertext: Vec<u8>,
/// Tag.
pub tag: Option<Vec<u8>>,
/// Version.
pub version: ProtocolVersion,
}
opaque_debug::implement!(Record);
@@ -167,10 +171,17 @@ impl TranscriptRefs {
}
}
/// Error for [`TranscriptRefs::from_transcript`].
/// Error for [`TlsTranscript`].
#[derive(Debug, thiserror::Error)]
#[error("not all application plaintext was committed to in the TLS transcript")]
pub struct IncompleteTranscript {}
#[error(transparent)]
pub struct TlsTranscriptError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
#[error("TLS transcript error")]
enum ErrorRepr {
#[error("not all application plaintext was committed to in the TLS transcript")]
IncompleteTranscript {},
}
/// Decodes the transcript.
pub fn decode_transcript(

View File

@@ -52,7 +52,6 @@ aes = { workspace = true }
aes-gcm = { workspace = true }
ctr = { workspace = true }
ghash_rc = { package = "ghash", version = "0.5" }
cipher-crate = { package = "cipher", version = "0.4" }
tokio = { workspace = true, features = ["sync"] }
pin-project-lite = { workspace = true }
@@ -61,13 +60,14 @@ mpz-ole = { workspace = true, features = ["test-utils"] }
mpz-ot = { workspace = true }
mpz-garble = { workspace = true }
cipher-crate = { package = "cipher", version = "0.4" }
generic-array = { workspace = true }
rand_chacha = { workspace = true }
rstest = { workspace = true }
tls-server-fixture = { workspace = true }
tlsn-tls-client = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
tokio-util = { workspace = true, features = ["compat"] }
tracing-subscriber = { workspace = true }
rand_chacha = { workspace = true }
generic-array = { workspace = true }
uid-mux = { workspace = true, features = ["serio", "test-utils"] }
rstest = { workspace = true }

View File

@@ -17,7 +17,7 @@ const PROTOCOL_RECORD_COUNT_RECV: usize = 2;
///
/// Accurately estimating a good default is challenging as we do not
/// know exactly how much data will be packed into each record in advance.
fn default_record_count(max_data: usize) -> usize {
pub(crate) fn default_record_count(max_data: usize) -> usize {
// We assume a minimum of 8 records for the first 4KB.
const MIN: usize = 8;
@@ -46,12 +46,16 @@ pub struct Config {
pub(crate) max_sent_records: usize,
/// Maximum number of sent bytes.
pub(crate) max_sent: usize,
/// Maximum number of received TLS records. Data is transmitted in records
/// up to 16KB long.
pub(crate) max_recv_records: usize,
/// Maximum number of received TLS records to be decrypted online, i.e
/// while the TLS connection is active.
///
/// Data is transmitted in records up to 16KB long.
pub(crate) max_recv_records_online: usize,
/// Maximum number of received bytes which will be decrypted while
/// the TLS connection is active. Data which can be decrypted after the TLS
/// connection will be decrypted for free.
/// the TLS connection is active.
///
/// Data which can be decrypted after the TLS connection will be
/// decrypted for free.
pub(crate) max_recv_online: usize,
/// Maximum number of received bytes.
#[allow(unused)]
@@ -101,9 +105,9 @@ impl ConfigBuilder {
let max_sent_records = self
.max_sent_records
.unwrap_or_else(|| PROTOCOL_RECORD_COUNT_SENT + default_record_count(max_sent));
let max_recv_records = self
.max_recv_records
.unwrap_or_else(|| PROTOCOL_RECORD_COUNT_RECV + default_record_count(max_recv));
let max_recv_records_online = self
.max_recv_records_online
.unwrap_or_else(|| PROTOCOL_RECORD_COUNT_RECV + default_record_count(max_recv_online));
let prf = self.prf.unwrap_or(PrfMode::Reduced);
@@ -111,7 +115,7 @@ impl ConfigBuilder {
defer_decryption,
max_sent_records,
max_sent,
max_recv_records,
max_recv_records_online,
max_recv_online,
max_recv,
prf,

View File

@@ -104,7 +104,7 @@ impl MpcTlsFollower {
return Err(MpcTlsError::state("must be in init state to allocate"));
};
let (keys, cf_vd, sf_vd) = {
let (keys, cf_vd, sf_vd, sw_mac_key) = {
let vm = &mut (*vm
.try_lock()
.map_err(|_| MpcTlsError::other("VM lock is held"))?);
@@ -121,21 +121,29 @@ impl MpcTlsFollower {
let cf_vd = vm.decode(cf_vd).map_err(MpcTlsError::alloc)?;
let sf_vd = vm.decode(sf_vd).map_err(MpcTlsError::alloc)?;
record_layer.alloc(
let server_write_mac_key = record_layer.alloc(
vm,
self.config.max_sent_records,
self.config.max_recv_records,
self.config.max_recv_records_online,
self.config.max_sent,
self.config.max_recv_online,
self.config.max_recv,
)?;
(keys, cf_vd, sf_vd)
(keys, cf_vd, sf_vd, server_write_mac_key)
};
let keys: SessionKeys = SessionKeys {
client_write_key: keys.client_write_key,
client_write_iv: keys.client_iv,
server_write_key: keys.server_write_key,
server_write_iv: keys.server_iv,
server_write_mac_key: sw_mac_key,
};
self.state = State::Setup {
vm,
keys: keys.into(),
keys: keys.clone(),
ke,
prf,
record_layer,
@@ -143,7 +151,7 @@ impl MpcTlsFollower {
sf_vd,
};
Ok(keys.into())
Ok(keys)
}
/// Preprocesses the connection.
@@ -470,7 +478,7 @@ fn validate_transcript(
)));
}
} else {
return Err(MpcTlsError::record_layer("no records were sent"));
return Err(MpcTlsError::record_layer("client finished was not sent"));
}
// Make sure the server finished verify data message was consistent.
@@ -498,7 +506,9 @@ fn validate_transcript(
)));
}
} else {
return Err(MpcTlsError::record_layer("no records were received"));
return Err(MpcTlsError::record_layer(
"server finished was not received",
));
}
// Verify last record sent was either application data or close notify.

View File

@@ -62,7 +62,7 @@ pub struct MpcTlsLeader {
}
impl MpcTlsLeader {
/// Create a new leader instance
/// Creates a new leader instance.
pub fn new<CS, CR>(
config: Config,
ctx: Context,
@@ -155,19 +155,27 @@ impl MpcTlsLeader {
let cf_vd = vm_lock.decode(cf_vd).map_err(MpcTlsError::alloc)?;
let sf_vd = vm_lock.decode(sf_vd).map_err(MpcTlsError::alloc)?;
record_layer.alloc(
let server_write_mac_key = record_layer.alloc(
&mut (*vm_lock),
self.config.max_sent_records,
self.config.max_recv_records,
self.config.max_recv_records_online,
self.config.max_sent,
self.config.max_recv_online,
self.config.max_recv,
)?;
let keys: SessionKeys = SessionKeys {
client_write_key: keys.client_write_key,
client_write_iv: keys.client_iv,
server_write_key: keys.server_write_key,
server_write_iv: keys.server_iv,
server_write_mac_key,
};
self.state = State::Setup {
ctx,
vm,
keys: keys.into(),
keys: keys.clone(),
ke,
prf,
record_layer,
@@ -176,7 +184,7 @@ impl MpcTlsLeader {
client_random,
};
Ok(keys.into())
Ok(keys)
}
/// Preprocesses the connection.

View File

@@ -58,17 +58,8 @@ pub struct SessionKeys {
pub server_write_key: Array<U8, 16>,
/// Server write IV.
pub server_write_iv: Array<U8, 4>,
}
impl From<hmac_sha256::SessionKeys> for SessionKeys {
fn from(keys: hmac_sha256::SessionKeys) -> Self {
Self {
client_write_key: keys.client_write_key,
client_write_iv: keys.client_iv,
server_write_key: keys.server_write_key,
server_write_iv: keys.server_iv,
}
}
/// Server write MAC key.
pub server_write_mac_key: Array<U8, 16>,
}
/// MPC-TLS Leader output.
@@ -99,7 +90,8 @@ pub struct LeaderOutput {
pub struct FollowerData {
/// Server ephemeral public key.
pub server_key: PublicKey,
/// TLS transcript.
/// TLS transcript in which the received records are unauthenticated
/// from the follower's perspective.
pub transcript: TlsTranscript,
/// TLS session keys.
pub keys: SessionKeys,

View File

@@ -1,7 +1,7 @@
//! TLS record layer.
pub(crate) mod aead;
mod aes_ctr;
mod aes_gcm;
mod decrypt;
mod encrypt;
@@ -26,7 +26,7 @@ use tokio::sync::Mutex;
use tracing::{debug, instrument};
use crate::{
record_layer::{aes_ctr::AesCtr, decrypt::DecryptOp, encrypt::EncryptOp},
record_layer::{aes_gcm::AesGcm, decrypt::DecryptOp, encrypt::EncryptOp},
MpcTlsError, Role, Vm,
};
pub(crate) use decrypt::DecryptMode;
@@ -59,7 +59,7 @@ enum State {
sent_records: Vec<Record>,
recv_records: Vec<Record>,
},
Complete,
Complete {},
Error,
}
@@ -76,7 +76,7 @@ pub(crate) struct RecordLayer {
read_seq: u64,
encrypter: Arc<Mutex<MpcAesGcm>>,
decrypt: Arc<Mutex<MpcAesGcm>>,
aes_ctr: AesCtr,
aes_gcm: AesGcm,
state: State,
/// Whether the record layer has started processing application data.
started: bool,
@@ -108,7 +108,7 @@ impl RecordLayer {
read_seq: 0,
encrypter: Arc::new(Mutex::new(encrypt)),
decrypt: Arc::new(Mutex::new(decrypt)),
aes_ctr: AesCtr::new(role),
aes_gcm: AesGcm::new(role),
state: State::Init,
started: false,
sent: 0,
@@ -124,7 +124,8 @@ impl RecordLayer {
}
}
/// Allocates resources for the record layer.
/// Allocates resources for the record layer, returning a reference
/// to the server write MAC key.
///
/// # Arguments
///
@@ -143,7 +144,7 @@ impl RecordLayer {
sent_len: usize,
recv_len_online: usize,
recv_len: usize,
) -> Result<(), MpcTlsError> {
) -> Result<Array<U8, 16>, MpcTlsError> {
let State::Init = self.state.take() else {
return Err(MpcTlsError::other("record layer is already allocated"));
};
@@ -176,7 +177,7 @@ impl RecordLayer {
Role::Follower => None,
};
self.aes_ctr.alloc(vm)?;
self.aes_gcm.alloc(vm)?;
self.max_sent += sent_len;
self.max_recv_online += recv_len_online;
@@ -188,7 +189,7 @@ impl RecordLayer {
recv_records: Vec::new(),
};
Ok(())
decrypt.ghash_key().map_err(MpcTlsError::record_layer)
}
pub(crate) async fn preprocess(&mut self, ctx: &mut Context) -> Result<(), MpcTlsError> {
@@ -236,7 +237,7 @@ impl RecordLayer {
encrypt.set_iv(client_iv);
decrypt.set_key(server_write_key);
decrypt.set_iv(server_iv);
self.aes_ctr.set_key(server_write_key, server_iv);
self.aes_gcm.set_key(server_write_key, server_iv);
Ok(())
}
@@ -475,12 +476,14 @@ impl RecordLayer {
for (op, pending) in encrypt_ops.into_iter().zip(pending_encrypt) {
let ciphertext = pending.output.try_encrypt()?;
let tag = tags.as_mut().and_then(Vec::pop);
self.encrypted_buffer.push_back(EncryptedRecord {
typ: op.typ,
version: op.version,
explicit_nonce: op.explicit_nonce.clone(),
ciphertext: ciphertext.clone(),
tag: tags.as_mut().and_then(Vec::pop),
tag: tag.clone(),
});
sent_records.push(Record {
@@ -490,6 +493,8 @@ impl RecordLayer {
plaintext_ref: pending.plaintext_ref,
explicit_nonce: op.explicit_nonce,
ciphertext,
tag,
version: op.version,
});
}
@@ -508,12 +513,16 @@ impl RecordLayer {
plaintext_ref: None,
explicit_nonce: op.explicit_nonce,
ciphertext: op.ciphertext,
tag: Some(op.tag),
version: op.version,
});
}
Ok(())
}
/// Commits to the record layer, returning a transcript in which the
/// received records are unauthenticated from the follower's perspective.
pub(crate) async fn commit(
&mut self,
ctx: &mut Context,
@@ -547,28 +556,16 @@ impl RecordLayer {
let buffered_ops = take(&mut self.decrypt_buffer);
// Verify tags of buffered ciphertexts.
let verify_tags = decrypt::verify_tags(&mut (*vm), &mut decrypter, &buffered_ops)?;
vm.execute_all(ctx)
.await
.map_err(MpcTlsError::record_layer)?;
verify_tags
.run(ctx)
.await
.map_err(MpcTlsError::record_layer)?;
// Reveal decrypt key to the leader.
self.aes_ctr.decode_key(&mut (*vm))?;
// Reveal decryption key to the leader.
self.aes_gcm.decode_key(&mut (*vm))?;
vm.flush(ctx).await.map_err(MpcTlsError::record_layer)?;
self.aes_ctr.finish_decode()?;
self.aes_gcm.finish_decode()?;
let pending_decrypts = decrypt::decrypt_local(
self.role,
&mut (*vm),
&mut decrypter,
&mut self.aes_ctr,
&mut self.aes_gcm,
&buffered_ops,
)?;
@@ -591,10 +588,12 @@ impl RecordLayer {
plaintext_ref: None,
explicit_nonce: op.explicit_nonce,
ciphertext: op.ciphertext,
tag: Some(op.tag),
version: op.version,
});
}
self.state = State::Complete;
self.state = State::Complete {};
Ok(TlsTranscript {
sent: sent_records,

View File

@@ -1,5 +1,16 @@
use std::{future::Future, sync::Arc};
use cipher::{aes::Aes128, Cipher, CtrBlock, Keystream};
use mpz_common::{Context, Flush};
use mpz_fields::gf2_128::Gf2_128;
use mpz_memory_core::{
binary::{Binary, U8},
Vector,
};
use mpz_share_conversion::ShareConvert;
use mpz_vm_core::{prelude::*, Vm};
use tracing::instrument;
use crate::{
decode::OneTimePadShared,
record_layer::{
@@ -11,16 +22,6 @@ use crate::{
},
Role,
};
use cipher::{aes::Aes128, Cipher, CtrBlock, Keystream};
use mpz_common::{Context, Flush};
use mpz_fields::gf2_128::Gf2_128;
use mpz_memory_core::{
binary::{Binary, U8},
Vector,
};
use mpz_share_conversion::ShareConvert;
use mpz_vm_core::{prelude::*, Vm};
use tracing::instrument;
const START_CTR: u32 = 2;
@@ -33,14 +34,16 @@ enum State {
input: Vector<U8>,
keystream: Keystream<Nonce, Ctr, Block>,
j0s: Vec<(CtrBlock<Nonce, Ctr, Block>, OneTimePadShared<[u8; 16]>)>,
ghash_key: OneTimePadShared<[u8; 16]>,
ghash_key_share: OneTimePadShared<[u8; 16]>,
ghash: Box<dyn Ghash + Send + Sync>,
ghash_key: Array<U8, 16>,
},
Ready {
input: Vector<U8>,
keystream: Keystream<Nonce, Ctr, Block>,
j0s: Vec<(CtrBlock<Nonce, Ctr, Block>, OneTimePadShared<[u8; 16]>)>,
ghash: Arc<dyn Ghash + Send + Sync>,
ghash_key: Array<U8, 16>,
},
Error,
}
@@ -96,7 +99,7 @@ impl MpcAesGcm {
ghash.alloc()?;
let ghash_key = self.aes.alloc_block(vm, zero_block)?;
let ghash_key = OneTimePadShared::<[u8; 16]>::new(self.role, ghash_key, vm)?;
let ghash_key_share = OneTimePadShared::<[u8; 16]>::new(self.role, ghash_key, vm)?;
// Allocate J0 secret sharing for GHASH.
let mut j0s = Vec::with_capacity(records);
@@ -129,6 +132,7 @@ impl MpcAesGcm {
keystream,
j0s,
ghash,
ghash_key_share,
ghash_key,
};
@@ -158,6 +162,7 @@ impl MpcAesGcm {
input,
keystream,
j0s,
ghash_key_share,
mut ghash,
ghash_key,
} = self.state.take()
@@ -165,7 +170,7 @@ impl MpcAesGcm {
return Err(AeadError::state("must be in setup state to set up"));
};
let key = ghash_key.await.map_err(AeadError::tag)?;
let key = ghash_key_share.await.map_err(AeadError::tag)?;
ghash.set_key(key.to_vec())?;
ghash.setup(ctx).await?;
@@ -174,6 +179,7 @@ impl MpcAesGcm {
keystream,
j0s,
ghash: Arc::from(ghash),
ghash_key,
};
Ok(())
@@ -300,6 +306,22 @@ impl MpcAesGcm {
Ok(keystream.to_vector(vm, len)?)
}
/// Returns the VM reference to the GHASH key.
#[instrument(level = "debug", skip_all, err)]
pub(crate) fn ghash_key(&mut self) -> Result<Array<U8, 16>, AeadError> {
let key = match self.state {
State::Setup { ghash_key, .. } => ghash_key,
State::Ready { ghash_key, .. } => ghash_key,
_ => {
return Err(AeadError::state(
"must be in setup or ready state to return ghash key",
))
}
};
Ok(key)
}
/// Computes tags for the provided ciphertext. See
/// [`verify_tags`](MpcAesGcm::verify_tags) for a method that verifies an
/// tags instead.

View File

@@ -259,29 +259,6 @@ fn compute_shares(key: Gf2_128, odd_powers: &[Gf2_128]) -> Vec<Gf2_128> {
shares
}
/// Builds padded data for GHASH.
pub(crate) fn build_ghash_data(mut aad: Vec<u8>, mut ciphertext: Vec<u8>) -> Vec<u8> {
let associated_data_bitlen = (aad.len() as u64) * 8;
let text_bitlen = (ciphertext.len() as u64) * 8;
let len_block = ((associated_data_bitlen as u128) << 64) + (text_bitlen as u128);
// Pad data to be a multiple of 16 bytes.
let aad_padded_block_count = (aad.len() / 16) + (aad.len() % 16 != 0) as usize;
aad.resize(aad_padded_block_count * 16, 0);
let ciphertext_padded_block_count =
(ciphertext.len() / 16) + (ciphertext.len() % 16 != 0) as usize;
ciphertext.resize(ciphertext_padded_block_count * 16, 0);
let mut data: Vec<u8> = Vec::with_capacity(aad.len() + ciphertext.len() + 16);
data.extend(aad);
data.extend(ciphertext);
data.extend_from_slice(&len_block.to_be_bytes());
data
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TagShare([u8; 16]);

View File

@@ -4,11 +4,12 @@ use async_trait::async_trait;
use futures::{stream::FuturesOrdered, StreamExt as _};
use mpz_common::{Context, Task};
use serio::{stream::IoStreamExt, SinkExt};
use tlsn_common::ghash::build_ghash_data;
use crate::{
decode::OneTimePadShared,
record_layer::aead::{
ghash::{build_ghash_data, Ghash, TagShare},
ghash::{Ghash, TagShare},
AeadError,
},
Role,

View File

@@ -5,11 +5,12 @@ use futures::{stream::FuturesOrdered, StreamExt};
use mpz_common::{Context, Task};
use mpz_core::commit::{Decommitment, HashCommit};
use serio::{stream::IoStreamExt, SinkExt};
use tlsn_common::ghash::build_ghash_data;
use crate::{
decode::OneTimePadShared,
record_layer::aead::{
ghash::{build_ghash_data, Ghash, TagShare},
ghash::{Ghash, TagShare},
AeadError,
},
Role,
@@ -26,6 +27,7 @@ pub(crate) struct VerifyTagData {
pub(crate) struct VerifyTags {
role: Role,
data: Vec<VerifyTagData>,
/// MPC implementation to use for computing GHASH.
ghash: Arc<dyn Ghash + Send + Sync>,
}
@@ -64,6 +66,7 @@ impl Task for VerifyTags {
let mut tag_shares = Vec::with_capacity(data.len());
let mut tags = Vec::with_capacity(data.len());
for (mut tag_share, data) in j0_shares.into_iter().zip(data) {
let ghash_share = ghash
.compute(&build_ghash_data(data.aad, data.ciphertext))

View File

@@ -1,4 +1,4 @@
use cipher_crate::{KeyIvInit, StreamCipher as _, StreamCipherSeek};
use aes_gcm::{aead::AeadMutInPlace, Aes128Gcm, NewAead};
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{
binary::{Binary, U8},
@@ -9,8 +9,6 @@ use rand::RngCore;
use crate::{MpcTlsError, Role};
type LocalAesCtr = ctr::Ctr32BE<aes::Aes128>;
enum State {
Init,
Alloc {
@@ -38,14 +36,14 @@ impl State {
}
}
pub(crate) struct AesCtr {
pub(crate) struct AesGcm {
role: Role,
key: Option<Array<U8, 16>>,
iv: Option<Array<U8, 4>>,
state: State,
}
impl AesCtr {
impl AesGcm {
pub(crate) fn new(role: Role) -> Self {
Self {
role,
@@ -131,6 +129,8 @@ impl AesCtr {
Ok(())
}
/// Finishes the decoding of key and IV.
#[allow(clippy::type_complexity)]
pub(crate) fn finish_decode(&mut self) -> Result<(), MpcTlsError> {
let State::Decode {
mut masked_key,
@@ -181,7 +181,9 @@ impl AesCtr {
pub(crate) fn decrypt(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
aad: Vec<u8>,
mut ciphertext: Vec<u8>,
tag: Vec<u8>,
) -> Result<Vec<u8>, MpcTlsError> {
let State::Ready { key, iv, .. } = &self.state else {
Err(MpcTlsError::record_layer(
@@ -198,29 +200,22 @@ impl AesCtr {
let key = key.as_ref().expect("leader knows key");
let iv = iv.as_ref().expect("leader knows iv");
let explicit_nonce: [u8; 8] =
explicit_nonce
.try_into()
.map_err(|explicit_nonce: Vec<_>| {
MpcTlsError::record_layer(format!(
"incorrect explicit nonce length: {} != 8",
explicit_nonce.len()
))
})?;
let mut aes_gcm = Aes128Gcm::new(key.into());
let mut full_iv = [0u8; 16];
let mut full_iv = [0u8; 12];
full_iv[..4].copy_from_slice(iv);
full_iv[4..12].copy_from_slice(&explicit_nonce);
let mut aes = LocalAesCtr::new(key.into(), &full_iv.into());
aes_gcm
.decrypt_in_place_detached(
(&full_iv).into(),
&aad,
&mut ciphertext,
tag.as_slice().into(),
)
.map_err(|_| MpcTlsError::record_layer("tag verification failed"))?;
// Skip the first 32 bytes of the keystream to match the AES-GCM implementation.
aes.seek(32);
let mut plaintext = ciphertext;
aes.apply_keystream(&mut plaintext);
Ok(plaintext)
Ok(ciphertext)
}
}
@@ -230,17 +225,18 @@ mod tests {
use aes_gcm::{aead::AeadMutInPlace, Aes128Gcm, NewAead};
#[test]
fn test_aes_ctr_local() {
fn test_aes_gcm_local() {
let key = [0u8; 16];
let iv = [42u8; 4];
let explicit_nonce = [69u8; 8];
let aad = [33u8; 13];
let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&iv);
nonce[4..].copy_from_slice(&explicit_nonce);
let mut aes_ctr = AesCtr::new(Role::Leader);
aes_ctr.state = State::Ready {
let mut aes_gcm_local = AesGcm::new(Role::Leader);
aes_gcm_local.state = State::Ready {
key: Some(key),
iv: Some(iv),
};
@@ -250,12 +246,17 @@ mod tests {
let msg = b"hello world";
let mut ciphertext = msg.to_vec();
_ = aes_gcm
.encrypt_in_place_detached(&nonce.into(), &[], &mut ciphertext)
let tag = aes_gcm
.encrypt_in_place_detached(&nonce.into(), &aad, &mut ciphertext)
.unwrap();
let decrypted = aes_ctr
.decrypt(explicit_nonce.to_vec(), ciphertext)
let decrypted = aes_gcm_local
.decrypt(
explicit_nonce.to_vec(),
aad.to_vec(),
ciphertext,
tag.to_vec(),
)
.unwrap();
assert_eq!(msg, decrypted.as_slice());

View File

@@ -7,7 +7,7 @@ use tls_core::msgs::enums::{ContentType, ProtocolVersion};
use crate::{
record_layer::{
aead::{MpcAesGcm, VerifyTags},
aes_ctr::AesCtr,
aes_gcm::AesGcm,
TagData,
},
MpcTlsError, Role,
@@ -98,7 +98,7 @@ pub(crate) fn decrypt_local(
role: Role,
vm: &mut dyn Vm<Binary>,
mpc_decrypter: &mut MpcAesGcm,
local_decrypter: &mut AesCtr,
local_decrypter: &mut AesGcm,
ops: &[DecryptOp],
) -> Result<Vec<PendingDecrypt>, MpcTlsError> {
let mut pending_decrypt = Vec::with_capacity(ops.len());
@@ -106,8 +106,12 @@ pub(crate) fn decrypt_local(
match op.mode {
DecryptMode::Private => {
let plaintext = if let Role::Leader = role {
let plaintext = local_decrypter
.decrypt(op.explicit_nonce.clone(), op.ciphertext.clone())?;
let plaintext = local_decrypter.decrypt(
op.explicit_nonce.clone(),
op.aad.clone(),
op.ciphertext.clone(),
op.tag.clone(),
)?;
Some(plaintext)
} else {
None

View File

@@ -42,6 +42,7 @@ async fn mpc_tls_test() {
async fn leader_task(mut leader: MpcTlsLeader) {
leader.alloc().unwrap();
leader.preprocess().await.unwrap();
let (leader_ctrl, leader_fut) = leader.run();

View File

@@ -51,8 +51,8 @@ impl ProverConfig {
builder.max_sent_records(max_sent_records);
}
if let Some(max_recv_records) = self.protocol_config.max_recv_records() {
builder.max_recv_records(max_recv_records);
if let Some(max_recv_records_online) = self.protocol_config.max_recv_records_online() {
builder.max_recv_records_online(max_recv_records_online);
}
if let NetworkSetting::Bandwidth = self.protocol_config.network() {

View File

@@ -1,6 +1,6 @@
use mpc_tls::MpcTlsError;
use std::{error::Error, fmt};
use tlsn_common::{encoding::EncodingError, zk_aes::ZkAesCtrError};
use tlsn_common::{encoding::EncodingError, zk_aes_ctr::ZkAesCtrError};
/// Error for [`Prover`](crate::Prover).
#[derive(Debug, thiserror::Error)]

View File

@@ -32,8 +32,9 @@ use tlsn_common::{
context::build_mt_context,
encoding,
mux::attach_mux,
tag::verify_tags,
transcript::{decode_transcript, Record, TlsTranscript},
zk_aes::ZkAesCtr,
zk_aes_ctr::ZkAesCtr,
Role,
};
use tlsn_core::{
@@ -49,7 +50,7 @@ use tlsn_core::{
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{debug, info_span, instrument, Instrument, Span};
use tracing::{debug, info, info_span, instrument, Instrument, Span};
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>,
@@ -111,14 +112,14 @@ impl Prover<state::Initialized> {
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx);
// Allocate resources for MPC-TLS in VM.
// Allocate resources for MPC-TLS in the VM.
let mut keys = mpc_tls.alloc()?;
translate_keys(&mut keys, &vm.try_lock().expect("VM is not locked"))?;
// Allocate for committing to plaintext.
let mut zk_aes = ZkAesCtr::new(Role::Prover);
zk_aes.set_key(keys.server_write_key, keys.server_write_iv);
zk_aes.alloc(
let mut zk_aes_ctr = ZkAesCtr::new(Role::Prover);
zk_aes_ctr.set_key(keys.server_write_key, keys.server_write_iv);
zk_aes_ctr.alloc(
&mut (*vm.try_lock().expect("VM is not locked").zk()),
self.config.protocol_config().max_recv_data(),
)?;
@@ -136,7 +137,7 @@ impl Prover<state::Initialized> {
mux_ctrl,
mux_fut,
mpc_tls,
zk_aes,
zk_aes_ctr,
keys,
vm,
},
@@ -162,7 +163,7 @@ impl Prover<state::Setup> {
mux_ctrl,
mut mux_fut,
mpc_tls,
mut zk_aes,
mut zk_aes_ctr,
keys,
vm,
..
@@ -207,31 +208,23 @@ impl Prover<state::Setup> {
Ok::<_, ProverError>(())
};
let (_, (mut ctx, mut data)) = futures::try_join!(
info!("starting MPC-TLS");
let (_, (mut ctx, mut data, ..)) = futures::try_join!(
conn_fut,
mpc_fut.in_current_span().map_err(ProverError::from)
)?;
info!("finished MPC-TLS");
{
let mut vm = vm.try_lock().expect("VM should not be locked");
translate_transcript(&mut data.transcript, &vm)?;
// Prove received plaintext. Prover drops the proof output, as they trust
// themselves.
_ = commit_records(
&mut (*vm.zk()),
&mut zk_aes,
data.transcript
.recv
.iter_mut()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(ProverError::zk)?;
debug!("finalizing mpc");
// Finalize DEAP and execute the plaintext proofs.
// Finalize DEAP.
mux_fut
.poll_with(vm.finalize(&mut ctx))
.await
@@ -240,6 +233,38 @@ impl Prover<state::Setup> {
debug!("mpc finalized");
}
// Pull out ZK VM.
let (_, mut vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
// Prove tag verification of received records.
// The prover drops the proof output.
let _ = verify_tags(
&mut vm,
(data.keys.server_write_key, data.keys.server_write_iv),
data.keys.server_write_mac_key,
data.transcript.recv.clone(),
)
.map_err(ProverError::zk)?;
// Prove received plaintext. Prover drops the proof output, as
// they trust themselves.
_ = commit_records(
&mut vm,
&mut zk_aes_ctr,
data.transcript
.recv
.iter_mut()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(ProverError::zk)?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
.await?;
let transcript = data
.transcript
.to_transcript()
@@ -286,12 +311,6 @@ impl Prover<state::Setup> {
}),
};
// Pull out ZK VM.
let (_, vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
Ok(Prover {
config: self.config,
span: self.span,
@@ -628,6 +647,9 @@ fn translate_keys<Mpc, Zk>(keys: &mut SessionKeys, vm: &Deap<Mpc, Zk>) -> Result
keys.server_write_iv = vm
.translate(keys.server_write_iv)
.map_err(ProverError::mpc)?;
keys.server_write_mac_key = vm
.translate(keys.server_write_mac_key)
.map_err(ProverError::mpc)?;
Ok(())
}

View File

@@ -8,7 +8,7 @@ use mpc_tls::{MpcTlsLeader, SessionKeys};
use tlsn_common::{
mux::{MuxControl, MuxFuture},
transcript::TranscriptRefs,
zk_aes::ZkAesCtr,
zk_aes_ctr::ZkAesCtr,
};
use tlsn_core::{
connection::{ConnectionInfo, ServerCertData},
@@ -29,7 +29,7 @@ pub struct Setup {
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsLeader,
pub(crate) zk_aes: ZkAesCtr,
pub(crate) zk_aes_ctr: ZkAesCtr,
pub(crate) keys: SessionKeys,
pub(crate) vm: Arc<Mutex<Deap<Mpc, Zk>>>,
}

View File

@@ -20,8 +20,6 @@ const MAX_SENT_DATA: usize = 1 << 12;
const MAX_SENT_RECORDS: usize = 4;
// Maximum number of bytes that can be received by prover from server
const MAX_RECV_DATA: usize = 1 << 14;
// Maximum number of application records received by prover from server
const MAX_RECV_RECORDS: usize = 6;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore]
@@ -58,7 +56,6 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(notary_socke
.max_sent_data(MAX_SENT_DATA)
.max_sent_records(MAX_SENT_RECORDS)
.max_recv_data(MAX_RECV_DATA)
.max_recv_records(MAX_RECV_RECORDS)
.build()
.unwrap(),
)

View File

@@ -54,8 +54,8 @@ impl VerifierConfig {
builder.max_sent_records(max_sent_records);
}
if let Some(max_recv_records) = protocol_config.max_recv_records() {
builder.max_recv_records(max_recv_records);
if let Some(max_recv_records_online) = protocol_config.max_recv_records_online() {
builder.max_recv_records_online(max_recv_records_online);
}
if let NetworkSetting::Bandwidth = protocol_config.network() {

View File

@@ -1,6 +1,6 @@
use mpc_tls::MpcTlsError;
use std::{error::Error, fmt};
use tlsn_common::{encoding::EncodingError, zk_aes::ZkAesCtrError};
use tlsn_common::{encoding::EncodingError, zk_aes_ctr::ZkAesCtrError};
/// Error for [`Verifier`](crate::Verifier).
#[derive(Debug, thiserror::Error)]

View File

@@ -28,8 +28,9 @@ use tlsn_common::{
context::build_mt_context,
encoding,
mux::attach_mux,
tag::verify_tags,
transcript::{decode_transcript, verify_transcript, Record, TlsTranscript},
zk_aes::ZkAesCtr,
zk_aes_ctr::ZkAesCtr,
Role,
};
use tlsn_core::{
@@ -121,9 +122,9 @@ impl Verifier<state::Initialized> {
translate_keys(&mut keys, &vm.try_lock().expect("VM is not locked"))?;
// Allocate for committing to plaintext.
let mut zk_aes = ZkAesCtr::new(Role::Verifier);
zk_aes.set_key(keys.server_write_key, keys.server_write_iv);
zk_aes.alloc(
let mut zk_aes_ctr = ZkAesCtr::new(Role::Verifier);
zk_aes_ctr.set_key(keys.server_write_key, keys.server_write_iv);
zk_aes_ctr.alloc(
&mut (*vm.try_lock().expect("VM is not locked").zk()),
protocol_config.max_recv_data(),
)?;
@@ -142,7 +143,7 @@ impl Verifier<state::Initialized> {
mux_fut,
delta,
mpc_tls,
zk_aes,
zk_aes_ctr,
_keys: keys,
vm,
},
@@ -210,7 +211,7 @@ impl Verifier<state::Setup> {
mut mux_fut,
delta,
mpc_tls,
mut zk_aes,
mut zk_aes_ctr,
vm,
..
} = self.state;
@@ -227,6 +228,7 @@ impl Verifier<state::Setup> {
FollowerData {
server_key,
mut transcript,
keys,
..
},
) = mux_fut.poll_with(mpc_tls.run()).await?;
@@ -238,31 +240,55 @@ impl Verifier<state::Setup> {
translate_transcript(&mut transcript, &vm)?;
// Prepare for the prover to prove received plaintext.
let proof = commit_records(
&mut (*vm.zk()),
&mut zk_aes,
transcript
.recv
.iter_mut()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(VerifierError::zk)?;
debug!("finalizing mpc");
// Finalize DEAP and execute the plaintext proofs.
mux_fut
.poll_with(vm.finalize(&mut ctx))
.await
.map_err(VerifierError::mpc)?;
debug!("mpc finalized");
// Verify the plaintext proofs.
proof.verify().map_err(VerifierError::zk)?;
}
// Pull out ZK VM.
let (_, mut vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
// Prepare for the prover to prove tag verification of the received
// records.
let tag_proof = verify_tags(
&mut vm,
(keys.server_write_key, keys.server_write_iv),
keys.server_write_mac_key,
transcript.recv.clone(),
)
.map_err(VerifierError::zk)?;
// Prepare for the prover to prove received plaintext.
let proof = commit_records(
&mut vm,
&mut zk_aes_ctr,
transcript
.recv
.iter_mut()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(VerifierError::zk)?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(VerifierError::zk))
.await?;
// Verify the tags.
// After the verification, the entire TLS trancript becomes
// authenticated from the verifier's perspective.
tag_proof.verify().map_err(VerifierError::zk)?;
// Verify the plaintext proofs.
proof.verify().map_err(VerifierError::zk)?;
let sent = transcript
.sent
.iter()
@@ -286,12 +312,6 @@ impl Verifier<state::Setup> {
transcript_length: TranscriptLength { sent, received },
};
// Pull out ZK VM.
let (_, vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
Ok(Verifier {
config: self.config,
span: self.span,
@@ -576,6 +596,9 @@ fn translate_keys<Mpc, Zk>(
keys.server_write_iv = vm
.translate(keys.server_write_iv)
.map_err(VerifierError::mpc)?;
keys.server_write_mac_key = vm
.translate(keys.server_write_mac_key)
.map_err(VerifierError::mpc)?;
Ok(())
}

View File

@@ -9,7 +9,7 @@ use mpz_memory_core::correlated::Delta;
use tlsn_common::{
mux::{MuxControl, MuxFuture},
transcript::TranscriptRefs,
zk_aes::ZkAesCtr,
zk_aes_ctr::ZkAesCtr,
};
use tlsn_core::connection::{ConnectionInfo, ServerEphemKey};
use tlsn_deap::Deap;
@@ -23,13 +23,13 @@ pub struct Initialized;
opaque_debug::implement!(Initialized);
/// State after MPC setup has completed.
/// State after setup has completed.
pub struct Setup {
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) delta: Delta,
pub(crate) mpc_tls: MpcTlsFollower,
pub(crate) zk_aes: ZkAesCtr,
pub(crate) zk_aes_ctr: ZkAesCtr,
pub(crate) _keys: SessionKeys,
pub(crate) vm: Arc<Mutex<Deap<Mpc, Zk>>>,
}

View File

@@ -11,7 +11,7 @@ pub struct ProverConfig {
pub max_recv_data: usize,
pub defer_decryption_from_start: Option<bool>,
pub max_sent_records: Option<usize>,
pub max_recv_records: Option<usize>,
pub max_recv_records_online: Option<usize>,
pub network: NetworkSetting,
}
@@ -30,8 +30,8 @@ impl From<ProverConfig> for tlsn_prover::ProverConfig {
builder.max_sent_records(value);
}
if let Some(value) = value.max_recv_records {
builder.max_recv_records(value);
if let Some(value) = value.max_recv_records_online {
builder.max_recv_records_online(value);
}
if let Some(value) = value.defer_decryption_from_start {

View File

@@ -8,7 +8,7 @@ pub struct VerifierConfig {
pub max_sent_data: usize,
pub max_recv_data: usize,
pub max_sent_records: Option<usize>,
pub max_recv_records: Option<usize>,
pub max_recv_records_online: Option<usize>,
}
impl From<VerifierConfig> for tlsn_verifier::VerifierConfig {
@@ -22,8 +22,8 @@ impl From<VerifierConfig> for tlsn_verifier::VerifierConfig {
builder.max_sent_records(value);
}
if let Some(value) = value.max_recv_records {
builder.max_recv_records(value);
if let Some(value) = value.max_recv_records_online {
builder.max_recv_records_online(value);
}
let validator = builder.build().unwrap();