feat(tlsn): disclose encryption key (#1010)

Co-authored-by: th4s <th4s@metavoid.xyz>
This commit is contained in:
sinu.eth
2025-10-10 08:32:50 -07:00
committed by GitHub
parent bf1cf2302a
commit 6b9f44e7e5
15 changed files with 559 additions and 508 deletions

2
Cargo.lock generated
View File

@@ -7150,6 +7150,8 @@ dependencies = [
name = "tlsn"
version = "0.1.0-alpha.13-pre"
dependencies = [
"aes 0.8.4",
"ctr 0.9.2",
"derive_builder 0.12.0",
"futures",
"ghash 0.5.1",

View File

@@ -41,6 +41,8 @@ mpz-ot = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-zk = { workspace = true }
aes = { workspace = true }
ctr = { workspace = true }
derive_builder = { workspace = true }
futures = { workspace = true }
opaque-debug = { workspace = true }

View File

@@ -1,166 +0,0 @@
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{DecodeFutureTyped, MemoryExt, ViewExt, binary::Binary};
use mpz_vm_core::Vm;
use rangeset::{Difference, RangeSet, Subset};
use crate::{
commit::transcript::ReferenceMap,
zk_aes_ctr::{ZkAesCtr, ZkAesCtrError},
};
pub(crate) fn prove_plaintext(
vm: &mut dyn Vm<Binary>,
zk_aes: &mut ZkAesCtr,
plaintext: &[u8],
ranges: &RangeSet<usize>,
public: &RangeSet<usize>,
) -> Result<ReferenceMap, PlaintextAuthError> {
assert!(public.is_subset(ranges), "public is not a subset of ranges");
if ranges.is_empty() {
return Ok(ReferenceMap::default());
}
let (plaintext_map, ciphertext_map) = zk_aes
.alloc_plaintext(vm, ranges)
.map_err(ErrorRepr::ZkAesCtr)?;
for (range, chunk) in plaintext_map
.index(&ranges.difference(public))
.expect("map contains all ranges")
.iter()
{
vm.mark_private(*chunk).map_err(PlaintextAuthError::vm)?;
vm.assign(*chunk, plaintext[range].to_vec())
.map_err(PlaintextAuthError::vm)?;
vm.commit(*chunk).map_err(PlaintextAuthError::vm)?;
}
for (range, chunk) in plaintext_map
.index(public)
.expect("map contains all ranges")
.iter()
{
vm.mark_public(*chunk).map_err(PlaintextAuthError::vm)?;
vm.assign(*chunk, plaintext[range].to_vec())
.map_err(PlaintextAuthError::vm)?;
vm.commit(*chunk).map_err(PlaintextAuthError::vm)?;
}
for (_, chunk) in ciphertext_map.iter() {
drop(vm.decode(*chunk).map_err(PlaintextAuthError::vm)?);
}
Ok(plaintext_map)
}
pub(crate) fn verify_plaintext(
vm: &mut dyn Vm<Binary>,
zk_aes: &mut ZkAesCtr,
plaintext: &[u8],
ciphertext: &[u8],
ranges: &RangeSet<usize>,
public: &RangeSet<usize>,
) -> Result<(ReferenceMap, PlaintextProof), PlaintextAuthError> {
assert!(public.is_subset(ranges), "public is not a subset of ranges");
if ranges.is_empty() {
return Ok((
ReferenceMap::default(),
PlaintextProof {
ciphertexts: vec![],
},
));
}
let (plaintext_map, ciphertext_map) = zk_aes
.alloc_plaintext(vm, ranges)
.map_err(ErrorRepr::ZkAesCtr)?;
for (_, chunk) in plaintext_map
.index(&ranges.difference(public))
.expect("map contains all ranges")
.iter()
{
vm.mark_blind(*chunk).map_err(PlaintextAuthError::vm)?;
vm.commit(*chunk).map_err(PlaintextAuthError::vm)?;
}
for (range, chunk) in plaintext_map
.index(public)
.expect("map contains all ranges")
.iter()
{
vm.mark_public(*chunk).map_err(PlaintextAuthError::vm)?;
vm.assign(*chunk, plaintext[range].to_vec())
.map_err(PlaintextAuthError::vm)?;
vm.commit(*chunk).map_err(PlaintextAuthError::vm)?;
}
let mut ciphertexts = Vec::new();
for (range, chunk) in ciphertext_map
.index(ranges)
.expect("map contains all ranges")
.iter()
{
ciphertexts.push((
ciphertext[range].to_vec(),
vm.decode(*chunk).map_err(PlaintextAuthError::vm)?,
));
}
Ok((plaintext_map, PlaintextProof { ciphertexts }))
}
#[derive(Debug, thiserror::Error)]
#[error("plaintext authentication error: {0}")]
pub(crate) struct PlaintextAuthError(#[from] ErrorRepr);
impl PlaintextAuthError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
}
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("vm error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("zk aes ctr error: {0}")]
ZkAesCtr(ZkAesCtrError),
#[error("missing decoding")]
MissingDecoding,
#[error("invalid ciphertext")]
InvalidCiphertext,
}
#[must_use]
pub(crate) struct PlaintextProof {
// (expected, actual)
#[allow(clippy::type_complexity)]
ciphertexts: Vec<(Vec<u8>, DecodeFutureTyped<BitVec, Vec<u8>>)>,
}
impl PlaintextProof {
pub(crate) fn verify(self) -> Result<(), PlaintextAuthError> {
let Self {
ciphertexts: ciphertext,
} = self;
for (expected, mut actual) in ciphertext {
let actual = actual
.try_recv()
.map_err(PlaintextAuthError::vm)?
.ok_or(PlaintextAuthError(ErrorRepr::MissingDecoding))?;
if actual != expected {
return Err(PlaintextAuthError(ErrorRepr::InvalidCiphertext));
}
}
Ok(())
}
}

View File

@@ -4,16 +4,15 @@
#![deny(clippy::all)]
#![forbid(unsafe_code)]
pub(crate) mod commit;
pub mod config;
pub(crate) mod context;
pub(crate) mod encoding;
pub(crate) mod ghash;
pub(crate) mod map;
pub(crate) mod mux;
pub mod prover;
pub(crate) mod tag;
pub(crate) mod transcript_internal;
pub mod verifier;
pub(crate) mod zk_aes_ctr;
pub use tlsn_attestation as attestation;
pub use tlsn_core::{connection, hash, transcript};

View File

@@ -3,15 +3,6 @@ use std::ops::Range;
use mpz_memory_core::{Vector, binary::U8};
use rangeset::RangeSet;
pub(crate) type ReferenceMap = RangeMap<Vector<U8>>;
/// References to the application plaintext in the transcript.
#[derive(Debug, Default, Clone)]
pub(crate) struct TranscriptRefs {
pub(crate) sent: ReferenceMap,
pub(crate) recv: ReferenceMap,
}
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct RangeMap<T> {
map: Vec<(usize, T)>,
@@ -44,6 +35,13 @@ where
Self { map }
}
/// Returns the keys of the map.
pub(crate) fn keys(&self) -> impl Iterator<Item = Range<usize>> {
self.map
.iter()
.map(|(idx, item)| *idx..*idx + item.length())
}
/// Returns the length of the map.
pub(crate) fn len(&self) -> usize {
self.map.iter().map(|(_, item)| item.length()).sum()

View File

@@ -2,7 +2,7 @@ use std::{error::Error, fmt};
use mpc_tls::MpcTlsError;
use crate::encoding::EncodingError;
use crate::transcript_internal::commit::encoding::EncodingError;
/// Error for [`Prover`](crate::Prover).
#[derive(Debug, thiserror::Error)]

View File

@@ -13,10 +13,15 @@ use tlsn_core::{
};
use crate::{
commit::{auth::prove_plaintext, hash::prove_hash, transcript::TranscriptRefs},
encoding::{self, MacStore},
prover::ProverError,
zk_aes_ctr::ZkAesCtr,
transcript_internal::{
TranscriptRefs,
auth::prove_plaintext,
commit::{
encoding::{self, MacStore},
hash::prove_hash,
},
},
};
pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
@@ -61,67 +66,51 @@ pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
.await
.map_err(ProverError::from)?;
let mut auth_sent_ranges = RangeSet::default();
let mut auth_recv_ranges = RangeSet::default();
let (reveal_sent, reveal_recv) = config.reveal().cloned().unwrap_or_default();
auth_sent_ranges.union_mut(&reveal_sent);
auth_recv_ranges.union_mut(&reveal_recv);
let (mut commit_sent, mut commit_recv) = (RangeSet::default(), RangeSet::default());
if let Some(commit_config) = config.transcript_commit() {
commit_config
.iter_hash()
.for_each(|((direction, idx), _)| match direction {
Direction::Sent => auth_sent_ranges.union_mut(idx),
Direction::Received => auth_recv_ranges.union_mut(idx),
Direction::Sent => commit_sent.union_mut(idx),
Direction::Received => commit_recv.union_mut(idx),
});
commit_config
.iter_encoding()
.for_each(|(direction, idx)| match direction {
Direction::Sent => auth_sent_ranges.union_mut(idx),
Direction::Received => auth_recv_ranges.union_mut(idx),
Direction::Sent => commit_sent.union_mut(idx),
Direction::Received => commit_recv.union_mut(idx),
});
}
let mut zk_aes_sent = ZkAesCtr::new(
keys.client_write_key,
keys.client_write_iv,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
);
let mut zk_aes_recv = ZkAesCtr::new(
keys.server_write_key,
keys.server_write_iv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
);
let sent_refs = prove_plaintext(
vm,
&mut zk_aes_sent,
transcript.sent(),
&auth_sent_ranges,
&reveal_sent,
)
.map_err(ProverError::commit)?;
let recv_refs = prove_plaintext(
vm,
&mut zk_aes_recv,
transcript.received(),
&auth_recv_ranges,
&reveal_recv,
)
.map_err(ProverError::commit)?;
let transcript_refs = TranscriptRefs {
sent: sent_refs,
recv: recv_refs,
sent: prove_plaintext(
vm,
keys.client_write_key,
keys.client_write_iv,
transcript.sent(),
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
&reveal_sent,
&commit_sent,
)
.map_err(ProverError::commit)?,
recv: prove_plaintext(
vm,
keys.server_write_key,
keys.server_write_iv,
transcript.received(),
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
&reveal_recv,
&commit_recv,
)
.map_err(ProverError::commit)?,
};
let hash_commitments = if let Some(commit_config) = config.transcript_commit()

View File

@@ -0,0 +1,16 @@
pub(crate) mod auth;
pub(crate) mod commit;
use mpz_memory_core::{Vector, binary::U8};
use crate::map::RangeMap;
/// Maps transcript ranges to VM references.
pub(crate) type ReferenceMap = RangeMap<Vector<U8>>;
/// References to the application plaintext in the transcript.
#[derive(Debug, Default, Clone)]
pub(crate) struct TranscriptRefs {
pub(crate) sent: ReferenceMap,
pub(crate) recv: ReferenceMap,
}

View File

@@ -0,0 +1,455 @@
use std::sync::Arc;
use aes::Aes128;
use ctr::{
Ctr32BE,
cipher::{KeyIvInit, StreamCipher, StreamCipherSeek},
};
use mpz_circuits::circuits::{AES128, xor};
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{
Array, DecodeFutureTyped, MemoryExt, Vector, ViewExt,
binary::{Binary, U8},
};
use mpz_vm_core::{Call, CallableExt, Vm};
use rangeset::{Difference, RangeSet, Union};
use tlsn_core::transcript::Record;
use crate::transcript_internal::ReferenceMap;
pub(crate) fn prove_plaintext<'a>(
vm: &mut dyn Vm<Binary>,
key: Array<U8, 16>,
iv: Array<U8, 4>,
plaintext: &[u8],
records: impl IntoIterator<Item = &'a Record>,
reveal: &RangeSet<usize>,
commit: &RangeSet<usize>,
) -> Result<ReferenceMap, PlaintextAuthError> {
let is_reveal_all = reveal == (0..plaintext.len());
let alloc_ranges = if is_reveal_all {
commit.clone()
} else {
// The plaintext is only partially revealed, so we need to authenticate in ZK.
commit.union(reveal)
};
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
let records = RecordParams::from_iter(records).collect::<Vec<_>>();
if is_reveal_all {
drop(vm.decode(key).map_err(PlaintextAuthError::vm)?);
drop(vm.decode(iv).map_err(PlaintextAuthError::vm)?);
for (range, slice) in plaintext_refs.iter() {
vm.mark_public(*slice).map_err(PlaintextAuthError::vm)?;
vm.assign(*slice, plaintext[range].to_vec())
.map_err(PlaintextAuthError::vm)?;
vm.commit(*slice).map_err(PlaintextAuthError::vm)?;
}
} else {
let private = commit.difference(reveal);
for (_, slice) in plaintext_refs
.index(&private)
.expect("all ranges are allocated")
.iter()
{
vm.mark_private(*slice).map_err(PlaintextAuthError::vm)?;
}
for (_, slice) in plaintext_refs
.index(reveal)
.expect("all ranges are allocated")
.iter()
{
vm.mark_public(*slice).map_err(PlaintextAuthError::vm)?;
}
for (range, slice) in plaintext_refs.iter() {
vm.assign(*slice, plaintext[range].to_vec())
.map_err(PlaintextAuthError::vm)?;
vm.commit(*slice).map_err(PlaintextAuthError::vm)?;
}
let ciphertext = alloc_ciphertext(vm, key, iv, plaintext_refs.clone(), &records)?;
for (_, slice) in ciphertext.iter() {
drop(vm.decode(*slice).map_err(PlaintextAuthError::vm)?);
}
}
Ok(plaintext_refs)
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn verify_plaintext<'a>(
vm: &mut dyn Vm<Binary>,
key: Array<U8, 16>,
iv: Array<U8, 4>,
plaintext: &'a [u8],
ciphertext: &'a [u8],
records: impl IntoIterator<Item = &'a Record>,
reveal: &RangeSet<usize>,
commit: &RangeSet<usize>,
) -> Result<(ReferenceMap, PlaintextProof<'a>), PlaintextAuthError> {
let is_reveal_all = reveal == (0..plaintext.len());
let alloc_ranges = if is_reveal_all {
commit.clone()
} else {
// The plaintext is only partially revealed, so we need to authenticate in ZK.
commit.union(reveal)
};
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
let records = RecordParams::from_iter(records).collect::<Vec<_>>();
let plaintext_proof = if is_reveal_all {
let key = vm.decode(key).map_err(PlaintextAuthError::vm)?;
let iv = vm.decode(iv).map_err(PlaintextAuthError::vm)?;
for (range, slice) in plaintext_refs.iter() {
vm.mark_public(*slice).map_err(PlaintextAuthError::vm)?;
vm.assign(*slice, plaintext[range].to_vec())
.map_err(PlaintextAuthError::vm)?;
vm.commit(*slice).map_err(PlaintextAuthError::vm)?;
}
PlaintextProof(ProofInner::WithKey {
key,
iv,
records,
plaintext,
ciphertext,
})
} else {
let private = commit.difference(reveal);
for (_, slice) in plaintext_refs
.index(&private)
.expect("all ranges are allocated")
.iter()
{
vm.mark_blind(*slice).map_err(PlaintextAuthError::vm)?;
}
for (range, slice) in plaintext_refs
.index(reveal)
.expect("all ranges are allocated")
.iter()
{
vm.mark_public(*slice).map_err(PlaintextAuthError::vm)?;
vm.assign(*slice, plaintext[range].to_vec())
.map_err(PlaintextAuthError::vm)?;
}
for (_, slice) in plaintext_refs.iter() {
vm.commit(*slice).map_err(PlaintextAuthError::vm)?;
}
let ciphertext_map = alloc_ciphertext(vm, key, iv, plaintext_refs.clone(), &records)?;
let mut ciphertexts = Vec::new();
for (range, chunk) in ciphertext_map.iter() {
ciphertexts.push((
&ciphertext[range],
vm.decode(*chunk).map_err(PlaintextAuthError::vm)?,
));
}
PlaintextProof(ProofInner::WithZk { ciphertexts })
};
Ok((plaintext_refs, plaintext_proof))
}
fn alloc_plaintext(
vm: &mut dyn Vm<Binary>,
ranges: &RangeSet<usize>,
) -> Result<ReferenceMap, PlaintextAuthError> {
let len = ranges.len();
let plaintext = vm.alloc_vec::<U8>(len).map_err(PlaintextAuthError::vm)?;
let mut pos = 0;
Ok(ReferenceMap::from_iter(ranges.iter_ranges().map(
move |range| {
let chunk = plaintext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
},
)))
}
fn alloc_ciphertext<'a>(
vm: &mut dyn Vm<Binary>,
key: Array<U8, 16>,
iv: Array<U8, 4>,
plaintext: ReferenceMap,
records: impl IntoIterator<Item = &'a RecordParams>,
) -> Result<ReferenceMap, PlaintextAuthError> {
let ranges = RangeSet::from(plaintext.keys().collect::<Vec<_>>());
let keystream = alloc_keystream(vm, key, iv, &ranges, records)?;
let mut builder = Call::builder(Arc::new(xor(ranges.len() * 8)));
for (_, slice) in plaintext.iter() {
builder = builder.arg(*slice);
}
for slice in keystream {
builder = builder.arg(slice);
}
let call = builder.build().expect("call should be valid");
let ciphertext: Vector<U8> = vm.call(call).map_err(PlaintextAuthError::vm)?;
let mut pos = 0;
Ok(ReferenceMap::from_iter(ranges.iter_ranges().map(
move |range| {
let chunk = ciphertext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
},
)))
}
fn alloc_keystream<'a>(
vm: &mut dyn Vm<Binary>,
key: Array<U8, 16>,
iv: Array<U8, 4>,
ranges: &RangeSet<usize>,
records: impl IntoIterator<Item = &'a RecordParams>,
) -> Result<Vec<Vector<U8>>, PlaintextAuthError> {
let mut keystream = Vec::new();
let mut pos = 0;
let mut range_iter = ranges.iter_ranges();
let mut current_range = range_iter.next();
for record in records {
let mut explicit_nonce = None;
let mut current_block = None;
loop {
let Some(range) = current_range.take().or_else(|| range_iter.next()) else {
return Ok(keystream);
};
if range.start >= pos + record.len {
current_range = Some(range);
break;
}
let explicit_nonce = if let Some(explicit_nonce) = explicit_nonce {
explicit_nonce
} else {
let nonce = alloc_explicit_nonce(vm, record.explicit_nonce.clone())?;
explicit_nonce = Some(nonce);
nonce
};
const BLOCK_SIZE: usize = 16;
let block_num = (range.start - pos) / BLOCK_SIZE;
let block = if let Some((current_block_num, block)) = current_block.take()
&& current_block_num == block_num
{
block
} else {
let block = alloc_block(vm, key, iv, explicit_nonce, block_num)?;
current_block = Some((block_num, block));
block
};
let start = (range.start - pos) % BLOCK_SIZE;
let end = (start + range.len()).min(BLOCK_SIZE);
let len = end - start;
keystream.push(block.get(start..end).expect("range is checked"));
// If the range is larger than a block, process the tail.
if range.len() > BLOCK_SIZE {
current_range = Some(range.start + len..range.end);
}
}
pos += record.len;
}
Err(ErrorRepr::OutOfBounds.into())
}
fn alloc_explicit_nonce(
vm: &mut dyn Vm<Binary>,
explicit_nonce: Vec<u8>,
) -> Result<Vector<U8>, PlaintextAuthError> {
const EXPLICIT_NONCE_LEN: usize = 8;
let nonce = vm
.alloc_vec::<U8>(EXPLICIT_NONCE_LEN)
.map_err(PlaintextAuthError::vm)?;
vm.mark_public(nonce).map_err(PlaintextAuthError::vm)?;
vm.assign(nonce, explicit_nonce)
.map_err(PlaintextAuthError::vm)?;
vm.commit(nonce).map_err(PlaintextAuthError::vm)?;
Ok(nonce)
}
fn alloc_block(
vm: &mut dyn Vm<Binary>,
key: Array<U8, 16>,
iv: Array<U8, 4>,
explicit_nonce: Vector<U8>,
block: usize,
) -> Result<Vector<U8>, PlaintextAuthError> {
let ctr: Array<U8, 4> = vm.alloc().map_err(PlaintextAuthError::vm)?;
vm.mark_public(ctr).map_err(PlaintextAuthError::vm)?;
const START_CTR: u32 = 2;
vm.assign(ctr, (START_CTR + block as u32).to_be_bytes())
.map_err(PlaintextAuthError::vm)?;
vm.commit(ctr).map_err(PlaintextAuthError::vm)?;
let block: Array<U8, 16> = vm
.call(
Call::builder(AES128.clone())
.arg(key)
.arg(iv)
.arg(explicit_nonce)
.arg(ctr)
.build()
.expect("call should be valid"),
)
.map_err(PlaintextAuthError::vm)?;
Ok(Vector::from(block))
}
struct RecordParams {
explicit_nonce: Vec<u8>,
len: usize,
}
impl RecordParams {
fn from_iter<'a>(records: impl IntoIterator<Item = &'a Record>) -> impl Iterator<Item = Self> {
records.into_iter().map(|record| Self {
explicit_nonce: record.explicit_nonce.clone(),
len: record.ciphertext.len(),
})
}
}
#[must_use]
pub(crate) struct PlaintextProof<'a>(ProofInner<'a>);
impl<'a> PlaintextProof<'a> {
pub(crate) fn verify(self) -> Result<(), PlaintextAuthError> {
match self.0 {
ProofInner::WithKey {
mut key,
mut iv,
records,
plaintext,
ciphertext,
} => {
let key = key
.try_recv()
.map_err(PlaintextAuthError::vm)?
.ok_or(ErrorRepr::MissingDecoding)?;
let iv = iv
.try_recv()
.map_err(PlaintextAuthError::vm)?
.ok_or(ErrorRepr::MissingDecoding)?;
verify_plaintext_with_key(key, iv, &records, plaintext, ciphertext)?;
}
ProofInner::WithZk { ciphertexts } => {
for (expected, mut actual) in ciphertexts {
let actual = actual
.try_recv()
.map_err(PlaintextAuthError::vm)?
.ok_or(PlaintextAuthError(ErrorRepr::MissingDecoding))?;
if actual != expected {
return Err(PlaintextAuthError(ErrorRepr::InvalidPlaintext));
}
}
}
}
Ok(())
}
}
enum ProofInner<'a> {
WithKey {
key: DecodeFutureTyped<BitVec, [u8; 16]>,
iv: DecodeFutureTyped<BitVec, [u8; 4]>,
records: Vec<RecordParams>,
plaintext: &'a [u8],
ciphertext: &'a [u8],
},
WithZk {
// (expected, actual)
#[allow(clippy::type_complexity)]
ciphertexts: Vec<(&'a [u8], DecodeFutureTyped<BitVec, Vec<u8>>)>,
},
}
fn verify_plaintext_with_key<'a>(
key: [u8; 16],
iv: [u8; 4],
records: impl IntoIterator<Item = &'a RecordParams>,
plaintext: &[u8],
ciphertext: &[u8],
) -> Result<(), PlaintextAuthError> {
let mut pos = 0;
let mut text = Vec::new();
for record in records {
let mut full_iv = [0u8; 16];
full_iv[0..4].copy_from_slice(&iv);
full_iv[4..12].copy_from_slice(&record.explicit_nonce[..8]);
const START_CTR: u32 = 2;
let mut cipher = Ctr32BE::<Aes128>::new(&key.into(), &full_iv.into());
cipher
.try_seek(START_CTR * 16)
.expect("start counter is less than keystream length");
text.clear();
text.extend_from_slice(&plaintext[pos..pos + record.len]);
cipher.apply_keystream(&mut text);
if text != ciphertext[pos..pos + record.len] {
return Err(PlaintextAuthError(ErrorRepr::InvalidPlaintext));
}
pos += record.len;
}
Ok(())
}
#[derive(Debug, thiserror::Error)]
#[error("plaintext authentication error: {0}")]
pub(crate) struct PlaintextAuthError(#[from] ErrorRepr);
impl PlaintextAuthError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
}
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("vm error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("plaintext out of bounds of records. This should never happen and is an internal bug.")]
OutOfBounds,
#[error("missing decoding")]
MissingDecoding,
#[error("plaintext does not match ciphertext")]
InvalidPlaintext,
}

View File

@@ -1,5 +1,4 @@
//! Plaintext commitment and proof of encryption.
pub(crate) mod auth;
pub(crate) mod encoding;
pub(crate) mod hash;
pub(crate) mod transcript;

View File

@@ -23,7 +23,10 @@ use tlsn_core::{
},
};
use crate::commit::transcript::{Item, RangeMap, ReferenceMap};
use crate::{
map::{Item, RangeMap},
transcript_internal::ReferenceMap,
};
/// Bytes of encoding, per byte.
const ENCODING_SIZE: usize = 128;

View File

@@ -18,7 +18,7 @@ use tlsn_core::{
},
};
use crate::{Role, commit::transcript::TranscriptRefs};
use crate::{Role, transcript_internal::TranscriptRefs};
/// Future which will resolve to the committed hash values.
#[derive(Debug)]

View File

@@ -1,7 +1,9 @@
use crate::encoding::EncodingError;
use mpc_tls::MpcTlsError;
use std::{error::Error, fmt};
use mpc_tls::MpcTlsError;
use crate::transcript_internal::commit::encoding::EncodingError;
/// Error for [`Verifier`](crate::Verifier).
#[derive(Debug, thiserror::Error)]
pub struct VerifierError {

View File

@@ -12,10 +12,15 @@ use tlsn_core::{
};
use crate::{
commit::{auth::verify_plaintext, hash::verify_hash, transcript::TranscriptRefs},
encoding::{self, KeyStore},
transcript_internal::{
TranscriptRefs,
auth::verify_plaintext,
commit::{
encoding::{self, KeyStore},
hash::verify_hash,
},
},
verifier::VerifierError,
zk_aes_ctr::ZkAesCtr,
};
pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
@@ -65,59 +70,47 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
None
};
let mut auth_sent_ranges = RangeSet::default();
let mut auth_recv_ranges = RangeSet::default();
auth_sent_ranges.union_mut(transcript.sent_authed());
auth_recv_ranges.union_mut(transcript.received_authed());
let (mut commit_sent, mut commit_recv) = (RangeSet::default(), RangeSet::default());
if let Some(commit_config) = transcript_commit.as_ref() {
commit_config
.iter_hash()
.for_each(|(direction, idx, _)| match direction {
Direction::Sent => auth_sent_ranges.union_mut(idx),
Direction::Received => auth_recv_ranges.union_mut(idx),
Direction::Sent => commit_sent.union_mut(idx),
Direction::Received => commit_recv.union_mut(idx),
});
if let Some((sent, recv)) = commit_config.encoding() {
auth_sent_ranges.union_mut(sent);
auth_recv_ranges.union_mut(recv);
commit_sent.union_mut(sent);
commit_recv.union_mut(recv);
}
}
let mut zk_aes_sent = ZkAesCtr::new(
let (sent_refs, sent_proof) = verify_plaintext(
vm,
keys.client_write_key,
keys.client_write_iv,
transcript.sent_unsafe(),
&ciphertext_sent,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
);
let mut zk_aes_recv = ZkAesCtr::new(
keys.server_write_key,
keys.server_write_iv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
);
let (sent_refs, sent_proof) = verify_plaintext(
vm,
&mut zk_aes_sent,
transcript.sent_unsafe(),
&ciphertext_sent,
&auth_sent_ranges,
transcript.sent_authed(),
&commit_sent,
)
.map_err(VerifierError::zk)?;
let (recv_refs, recv_proof) = verify_plaintext(
vm,
&mut zk_aes_recv,
keys.server_write_key,
keys.server_write_iv,
transcript.received_unsafe(),
&ciphertext_recv,
&auth_recv_ranges,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
transcript.received_authed(),
&commit_recv,
)
.map_err(VerifierError::zk)?;

View File

@@ -1,241 +0,0 @@
use std::{ops::Range, sync::Arc};
use mpz_circuits::circuits::{AES128, xor};
use mpz_memory_core::{
Array, MemoryExt, Vector, ViewExt,
binary::{Binary, U8},
};
use mpz_vm_core::{Call, CallableExt, Vm};
use rangeset::RangeSet;
use tlsn_core::transcript::Record;
use crate::commit::transcript::ReferenceMap;
/// ZK AES-CTR encryption.
#[derive(Debug)]
pub(crate) struct ZkAesCtr {
key: Array<U8, 16>,
iv: Array<U8, 4>,
records: Vec<(usize, RecordState)>,
total_len: usize,
}
impl ZkAesCtr {
/// Creates a new instance.
pub(crate) fn new<'record>(
key: Array<U8, 16>,
iv: Array<U8, 4>,
records: impl IntoIterator<Item = &'record Record>,
) -> Self {
let mut pos = 0;
let mut record_state = Vec::new();
for record in records {
record_state.push((
pos,
RecordState {
explicit_nonce: Some(record.explicit_nonce.clone()),
explicit_nonce_ref: None,
range: pos..pos + record.ciphertext.len(),
},
));
pos += record.ciphertext.len();
}
Self {
key,
iv,
records: record_state,
total_len: pos,
}
}
/// Allocates the plaintext for the provided ranges.
///
/// Returns a reference to the plaintext and the ciphertext.
pub(crate) fn alloc_plaintext(
&mut self,
vm: &mut dyn Vm<Binary>,
ranges: &RangeSet<usize>,
) -> Result<(ReferenceMap, ReferenceMap), ZkAesCtrError> {
let len = ranges.len();
if len > self.total_len {
return Err(ZkAesCtrError(ErrorRepr::TranscriptBounds {
len,
max: self.total_len,
}));
}
let plaintext = vm.alloc_vec::<U8>(len).map_err(ZkAesCtrError::vm)?;
let keystream = self.alloc_keystream(vm, ranges)?;
let mut builder = Call::builder(Arc::new(xor(len * 8))).arg(plaintext);
for slice in keystream {
builder = builder.arg(slice);
}
let call = builder.build().expect("call should be valid");
let ciphertext: Vector<U8> = vm.call(call).map_err(ZkAesCtrError::vm)?;
let mut pos = 0;
let plaintext = ReferenceMap::from_iter(ranges.iter_ranges().map(move |range| {
let chunk = plaintext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
}));
let mut pos = 0;
let ciphertext = ReferenceMap::from_iter(ranges.iter_ranges().map(move |range| {
let chunk = ciphertext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
}));
Ok((plaintext, ciphertext))
}
fn alloc_keystream(
&mut self,
vm: &mut dyn Vm<Binary>,
ranges: &RangeSet<usize>,
) -> Result<Vec<Vector<U8>>, ZkAesCtrError> {
let mut keystream = Vec::new();
let mut range_iter = ranges.iter_ranges();
let mut current_range = range_iter.next();
for (pos, record) in self.records.iter_mut() {
let pos = *pos;
let mut current_block = None;
loop {
let Some(range) = current_range.take().or_else(|| range_iter.next()) else {
return Ok(keystream);
};
if range.start >= record.range.end {
current_range = Some(range);
break;
}
const BLOCK_SIZE: usize = 16;
let block_num = (range.start - pos) / BLOCK_SIZE;
let block = if let Some((current_block_num, block)) = current_block.take()
&& current_block_num == block_num
{
block
} else {
let block = record.alloc_block(vm, self.key, self.iv, block_num)?;
current_block = Some((block_num, block));
block
};
let start = (range.start - pos) % BLOCK_SIZE;
let end = (start + range.len()).min(BLOCK_SIZE);
let len = end - start;
keystream.push(block.get(start..end).expect("range is checked"));
// If the range is larger than a block, process the tail.
if range.len() > BLOCK_SIZE {
current_range = Some(range.start + len..range.end);
}
}
}
unreachable!("plaintext length was checked");
}
}
#[derive(Debug)]
struct RecordState {
explicit_nonce: Option<Vec<u8>>,
range: Range<usize>,
explicit_nonce_ref: Option<Vector<U8>>,
}
impl RecordState {
fn alloc_explicit_nonce(
&mut self,
vm: &mut dyn Vm<Binary>,
) -> Result<Vector<U8>, ZkAesCtrError> {
if let Some(explicit_nonce) = self.explicit_nonce_ref {
Ok(explicit_nonce)
} else {
const EXPLICIT_NONCE_LEN: usize = 8;
let explicit_nonce_ref = vm
.alloc_vec::<U8>(EXPLICIT_NONCE_LEN)
.map_err(ZkAesCtrError::vm)?;
vm.mark_public(explicit_nonce_ref)
.map_err(ZkAesCtrError::vm)?;
vm.assign(
explicit_nonce_ref,
self.explicit_nonce
.take()
.expect("explicit nonce only set once"),
)
.map_err(ZkAesCtrError::vm)?;
vm.commit(explicit_nonce_ref).map_err(ZkAesCtrError::vm)?;
self.explicit_nonce_ref = Some(explicit_nonce_ref);
Ok(explicit_nonce_ref)
}
}
fn alloc_block(
&mut self,
vm: &mut dyn Vm<Binary>,
key: Array<U8, 16>,
iv: Array<U8, 4>,
block: usize,
) -> Result<Vector<U8>, ZkAesCtrError> {
let explicit_nonce = self.alloc_explicit_nonce(vm)?;
let ctr: Array<U8, 4> = vm.alloc().map_err(ZkAesCtrError::vm)?;
vm.mark_public(ctr).map_err(ZkAesCtrError::vm)?;
const START_CTR: u32 = 2;
vm.assign(ctr, (START_CTR + block as u32).to_be_bytes())
.map_err(ZkAesCtrError::vm)?;
vm.commit(ctr).map_err(ZkAesCtrError::vm)?;
let block: Array<U8, 16> = vm
.call(
Call::builder(AES128.clone())
.arg(key)
.arg(iv)
.arg(explicit_nonce)
.arg(ctr)
.build()
.expect("call should be valid"),
)
.map_err(ZkAesCtrError::vm)?;
Ok(Vector::from(block))
}
}
/// Error for [`ZkAesCtr`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub(crate) struct ZkAesCtrError(#[from] ErrorRepr);
impl ZkAesCtrError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
}
#[derive(Debug, thiserror::Error)]
#[error("zk aes error")]
enum ErrorRepr {
#[error("vm error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("transcript bounds exceeded: {len} > {max}")]
TranscriptBounds { len: usize, max: usize },
}