feat(tlsn): partial plaintext auth (#1006)

Co-authored-by: th4s <th4s@metavoid.xyz>
Author: sinu.eth
Date: 2025-10-09 11:22:23 -07:00
Committed by: GitHub
parent df8d79c152
commit 2884be17e0
22 changed files with 1542 additions and 1294 deletions

Cargo.lock (generated): 631 changes. File diff suppressed because it is too large.


@@ -66,19 +66,19 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
 tlsn-wasm = { path = "crates/wasm" }
 tlsn = { path = "crates/tlsn" }
-mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
-mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
-mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
-mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
-mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
-mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
-mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
-mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
-mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
-mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
-mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
-mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
-mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
+mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
+mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
+mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
+mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
+mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
+mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
+mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
+mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
+mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
+mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
+mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
+mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
+mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
 rangeset = { version = "0.2" }
 serio = { version = "0.2" }


@@ -190,10 +190,10 @@ pub struct VerifyConfigBuilderError(#[from] VerifyConfigBuilderErrorRepr);
 #[derive(Debug, thiserror::Error)]
 enum VerifyConfigBuilderErrorRepr {}
-/// Payload sent to the verifier.
+/// Request to prove statements about the connection.
 #[doc(hidden)]
 #[derive(Debug, Serialize, Deserialize)]
-pub struct ProvePayload {
+pub struct ProveRequest {
     /// Handshake data.
     pub handshake: Option<(ServerName, HandshakeData)>,
     /// Transcript data.


@@ -2,7 +2,7 @@
 use std::{collections::HashSet, fmt};
-use rangeset::ToRangeSet;
+use rangeset::{ToRangeSet, UnionMut};
 use serde::{Deserialize, Serialize};
 use crate::{
@@ -114,7 +114,19 @@ impl TranscriptCommitConfig {
     /// Returns a request for the transcript commitments.
     pub fn to_request(&self) -> TranscriptCommitRequest {
         TranscriptCommitRequest {
-            encoding: self.has_encoding,
+            encoding: self.has_encoding.then(|| {
+                let mut sent = RangeSet::default();
+                let mut recv = RangeSet::default();
+                for (dir, idx) in self.iter_encoding() {
+                    match dir {
+                        Direction::Sent => sent.union_mut(idx),
+                        Direction::Received => recv.union_mut(idx),
+                    }
+                }
+                (sent, recv)
+            }),
             hash: self
                 .iter_hash()
                 .map(|((dir, idx), alg)| (*dir, idx.clone(), *alg))
@@ -289,14 +301,14 @@ impl fmt::Display for TranscriptCommitConfigBuilderError {
 /// Request to compute transcript commitments.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct TranscriptCommitRequest {
-    encoding: bool,
+    encoding: Option<(RangeSet<usize>, RangeSet<usize>)>,
     hash: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
 }
 impl TranscriptCommitRequest {
     /// Returns `true` if an encoding commitment is requested.
-    pub fn encoding(&self) -> bool {
-        self.encoding
+    pub fn has_encoding(&self) -> bool {
+        self.encoding.is_some()
     }
     /// Returns `true` if a hash commitment is requested.
@@ -308,6 +320,11 @@ impl TranscriptCommitRequest {
     pub fn iter_hash(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>, HashAlgId)> {
         self.hash.iter()
     }
+    /// Returns the ranges of the encoding commitments.
+    pub fn encoding(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
+        self.encoding.as_ref()
+    }
 }
 #[cfg(test)]
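The request now carries the committed ranges themselves rather than a flag. Below is a standalone sketch (not part of this diff) of the per-direction range collection performed in `to_request`, using the same `rangeset` 0.2 calls; the `Direction` enum here is a local stand-in for `tlsn_core::transcript::Direction`.

use rangeset::{RangeSet, UnionMut};

enum Direction {
    Sent,
    Received,
}

// Union every committed index into one RangeSet per direction, mirroring
// the new `has_encoding.then(|| { ... })` body above.
fn collect_encoding_ranges(
    commits: &[(Direction, RangeSet<usize>)],
) -> (RangeSet<usize>, RangeSet<usize>) {
    let mut sent = RangeSet::default();
    let mut recv = RangeSet::default();
    for (dir, idx) in commits {
        match dir {
            Direction::Sent => sent.union_mut(idx),
            Direction::Received => recv.union_mut(idx),
        }
    }
    (sent, recv)
}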


@@ -31,6 +31,7 @@ web-spawn = { workspace = true, optional = true }
 mpz-common = { workspace = true }
 mpz-core = { workspace = true }
+mpz-circuits = { workspace = true }
 mpz-garble = { workspace = true }
 mpz-garble-core = { workspace = true }
 mpz-hash = { workspace = true }


@@ -1,116 +1,5 @@
 //! Plaintext commitment and proof of encryption.
+pub(crate) mod auth;
 pub(crate) mod hash;
 pub(crate) mod transcript;
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{
DecodeFutureTyped, Vector,
binary::{Binary, U8},
};
use mpz_vm_core::{Vm, prelude::*};
use tlsn_core::transcript::Record;
use crate::{
Role,
zk_aes_ctr::{ZkAesCtr, ZkAesCtrError},
};
/// Commits the plaintext of the provided records, returning a proof of
/// encryption.
///
/// Writes the plaintext VM reference to the provided records.
pub(crate) fn commit_records<'record>(
vm: &mut dyn Vm<Binary>,
aes: &mut ZkAesCtr,
records: impl IntoIterator<Item = &'record Record>,
) -> Result<(Vec<Vector<U8>>, RecordProof), RecordProofError> {
let mut plaintexts = Vec::new();
let mut ciphertexts = Vec::new();
for record in records {
let (plaintext_ref, ciphertext_ref) = aes
.encrypt(vm, record.explicit_nonce.clone(), record.ciphertext.len())
.map_err(ErrorRepr::Aes)?;
if let Role::Prover = aes.role() {
let Some(plaintext) = record.plaintext.clone() else {
return Err(ErrorRepr::MissingPlaintext.into());
};
vm.assign(plaintext_ref, plaintext)
.map_err(RecordProofError::vm)?;
}
vm.commit(plaintext_ref).map_err(RecordProofError::vm)?;
let ciphertext = vm.decode(ciphertext_ref).map_err(RecordProofError::vm)?;
plaintexts.push(plaintext_ref);
ciphertexts.push((ciphertext, record.ciphertext.clone()));
}
Ok((plaintexts, RecordProof { ciphertexts }))
}
/// Proof of encryption.
#[derive(Debug)]
#[must_use]
#[allow(clippy::type_complexity)]
pub(crate) struct RecordProof {
ciphertexts: Vec<(DecodeFutureTyped<BitVec, Vec<u8>>, Vec<u8>)>,
}
impl RecordProof {
/// Verifies the proof.
pub(crate) fn verify(self) -> Result<(), RecordProofError> {
let Self { ciphertexts } = self;
for (mut ciphertext, expected) in ciphertexts {
let ciphertext = ciphertext
.try_recv()
.map_err(RecordProofError::vm)?
.ok_or_else(|| ErrorRepr::NotDecoded)?;
if ciphertext != expected {
return Err(ErrorRepr::InvalidCiphertext.into());
}
}
Ok(())
}
}
/// Error for [`RecordProof`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub(crate) struct RecordProofError(#[from] ErrorRepr);
impl RecordProofError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
pub(crate) fn is_insufficient(&self) -> bool {
match &self.0 {
ErrorRepr::Aes(err) => err.is_insufficient(),
_ => false,
}
}
}
#[derive(Debug, thiserror::Error)]
#[error("record proof error: {0}")]
enum ErrorRepr {
#[error("VM error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("zk aes error: {0}")]
Aes(ZkAesCtrError),
#[error("plaintext is missing")]
MissingPlaintext,
#[error("ciphertext was not decoded")]
NotDecoded,
#[error("ciphertext does not match expected")]
InvalidCiphertext,
}


@@ -0,0 +1,166 @@
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{DecodeFutureTyped, MemoryExt, ViewExt, binary::Binary};
use mpz_vm_core::Vm;
use rangeset::{Difference, RangeSet, Subset};
use crate::{
commit::transcript::ReferenceMap,
zk_aes_ctr::{ZkAesCtr, ZkAesCtrError},
};
pub(crate) fn prove_plaintext(
vm: &mut dyn Vm<Binary>,
zk_aes: &mut ZkAesCtr,
plaintext: &[u8],
ranges: &RangeSet<usize>,
public: &RangeSet<usize>,
) -> Result<ReferenceMap, PlaintextAuthError> {
assert!(public.is_subset(ranges), "public is not a subset of ranges");
if ranges.is_empty() {
return Ok(ReferenceMap::default());
}
let (plaintext_map, ciphertext_map) = zk_aes
.alloc_plaintext(vm, ranges)
.map_err(ErrorRepr::ZkAesCtr)?;
for (range, chunk) in plaintext_map
.index(&ranges.difference(public))
.expect("map contains all ranges")
.iter()
{
vm.mark_private(*chunk).map_err(PlaintextAuthError::vm)?;
vm.assign(*chunk, plaintext[range].to_vec())
.map_err(PlaintextAuthError::vm)?;
vm.commit(*chunk).map_err(PlaintextAuthError::vm)?;
}
for (range, chunk) in plaintext_map
.index(public)
.expect("map contains all ranges")
.iter()
{
vm.mark_public(*chunk).map_err(PlaintextAuthError::vm)?;
vm.assign(*chunk, plaintext[range].to_vec())
.map_err(PlaintextAuthError::vm)?;
vm.commit(*chunk).map_err(PlaintextAuthError::vm)?;
}
for (_, chunk) in ciphertext_map.iter() {
drop(vm.decode(*chunk).map_err(PlaintextAuthError::vm)?);
}
Ok(plaintext_map)
}
pub(crate) fn verify_plaintext(
vm: &mut dyn Vm<Binary>,
zk_aes: &mut ZkAesCtr,
plaintext: &[u8],
ciphertext: &[u8],
ranges: &RangeSet<usize>,
public: &RangeSet<usize>,
) -> Result<(ReferenceMap, PlaintextProof), PlaintextAuthError> {
assert!(public.is_subset(ranges), "public is not a subset of ranges");
if ranges.is_empty() {
return Ok((
ReferenceMap::default(),
PlaintextProof {
ciphertexts: vec![],
},
));
}
let (plaintext_map, ciphertext_map) = zk_aes
.alloc_plaintext(vm, ranges)
.map_err(ErrorRepr::ZkAesCtr)?;
for (_, chunk) in plaintext_map
.index(&ranges.difference(public))
.expect("map contains all ranges")
.iter()
{
vm.mark_blind(*chunk).map_err(PlaintextAuthError::vm)?;
vm.commit(*chunk).map_err(PlaintextAuthError::vm)?;
}
for (range, chunk) in plaintext_map
.index(public)
.expect("map contains all ranges")
.iter()
{
vm.mark_public(*chunk).map_err(PlaintextAuthError::vm)?;
vm.assign(*chunk, plaintext[range].to_vec())
.map_err(PlaintextAuthError::vm)?;
vm.commit(*chunk).map_err(PlaintextAuthError::vm)?;
}
let mut ciphertexts = Vec::new();
for (range, chunk) in ciphertext_map
.index(ranges)
.expect("map contains all ranges")
.iter()
{
ciphertexts.push((
ciphertext[range].to_vec(),
vm.decode(*chunk).map_err(PlaintextAuthError::vm)?,
));
}
Ok((plaintext_map, PlaintextProof { ciphertexts }))
}
#[derive(Debug, thiserror::Error)]
#[error("plaintext authentication error: {0}")]
pub(crate) struct PlaintextAuthError(#[from] ErrorRepr);
impl PlaintextAuthError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
}
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("vm error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("zk aes ctr error: {0}")]
ZkAesCtr(ZkAesCtrError),
#[error("missing decoding")]
MissingDecoding,
#[error("invalid ciphertext")]
InvalidCiphertext,
}
#[must_use]
pub(crate) struct PlaintextProof {
// (expected, actual)
#[allow(clippy::type_complexity)]
ciphertexts: Vec<(Vec<u8>, DecodeFutureTyped<BitVec, Vec<u8>>)>,
}
impl PlaintextProof {
pub(crate) fn verify(self) -> Result<(), PlaintextAuthError> {
let Self {
ciphertexts: ciphertext,
} = self;
for (expected, mut actual) in ciphertext {
let actual = actual
.try_recv()
.map_err(PlaintextAuthError::vm)?
.ok_or(PlaintextAuthError(ErrorRepr::MissingDecoding))?;
if actual != expected {
return Err(PlaintextAuthError(ErrorRepr::InvalidCiphertext));
}
}
Ok(())
}
}
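The split above is the core of partial plaintext authentication: the prover commits to every range it will later reveal or commit to, but only the `public` subset is assigned on the verifier's side; the rest stays private or blind. A minimal sketch (not part of this diff, assuming the `rangeset` 0.2 API imported above) of how the private portion is derived:

use rangeset::{Difference, RangeSet};

// Ranges that are authenticated but not revealed: the committed ranges minus
// the public ranges, exactly the `ranges.difference(public)` used above.
fn private_ranges(ranges: &RangeSet<usize>, public: &RangeSet<usize>) -> RangeSet<usize> {
    ranges.difference(public)
}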


@@ -149,9 +149,15 @@ fn hash_commit_inner(
         hasher
     };
-    for plaintext in refs.get(direction, &idx).expect("plaintext refs are valid") {
-        hasher.update(&plaintext);
+    let refs = match direction {
+        Direction::Sent => &refs.sent,
+        Direction::Received => &refs.recv,
+    };
+    for range in idx.iter_ranges() {
+        hasher.update(&refs.get(range).expect("plaintext refs are valid"));
     }
     hasher.update(&blinder);
     hasher.finalize(vm).map_err(HashCommitError::hasher)?
 }
@@ -164,9 +170,14 @@ fn hash_commit_inner(
         hasher
     };
-    for plaintext in refs.get(direction, &idx).expect("plaintext refs are valid") {
+    let refs = match direction {
+        Direction::Sent => &refs.sent,
+        Direction::Received => &refs.recv,
+    };
+    for range in idx.iter_ranges() {
         hasher
-            .update(vm, &plaintext)
+            .update(vm, &refs.get(range).expect("plaintext refs are valid"))
             .map_err(HashCommitError::hasher)?;
     }
     hasher
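A plain-Rust analog (not part of this diff) of the new loop: only the committed ranges are fed to the hasher, in order, followed by the blinder. `std::hash::Hasher` stands in for the VM-backed hashers used in this crate; the range iteration comes from `rangeset` 0.2.

use rangeset::RangeSet;
use std::hash::Hasher;

fn hash_committed_ranges(
    hasher: &mut impl Hasher,
    plaintext: &[u8],
    idx: &RangeSet<usize>,
    blinder: &[u8],
) {
    for range in idx.iter_ranges() {
        // Each committed range contributes its plaintext bytes, in order.
        hasher.write(&plaintext[range]);
    }
    // The blinder is appended last, as in the commitment above.
    hasher.write(blinder);
}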


@@ -1,211 +1,205 @@
use mpz_memory_core::{ use std::ops::Range;
MemoryExt, Vector,
binary::{Binary, U8}, use mpz_memory_core::{Vector, binary::U8};
}; use rangeset::RangeSet;
use mpz_vm_core::{Vm, VmError};
use rangeset::{Intersection, RangeSet}; pub(crate) type ReferenceMap = RangeMap<Vector<U8>>;
use tlsn_core::transcript::{Direction, PartialTranscript};
/// References to the application plaintext in the transcript. /// References to the application plaintext in the transcript.
#[derive(Debug, Default, Clone)] #[derive(Debug, Default, Clone)]
pub(crate) struct TranscriptRefs { pub(crate) struct TranscriptRefs {
sent: Vec<Vector<U8>>, pub(crate) sent: ReferenceMap,
recv: Vec<Vector<U8>>, pub(crate) recv: ReferenceMap,
} }
impl TranscriptRefs { #[derive(Debug, Clone, PartialEq)]
pub(crate) fn new(sent: Vec<Vector<U8>>, recv: Vec<Vector<U8>>) -> Self { pub(crate) struct RangeMap<T> {
Self { sent, recv } map: Vec<(usize, T)>,
}
impl<T> Default for RangeMap<T>
where
T: Item,
{
fn default() -> Self {
Self { map: Vec::new() }
} }
}
/// Returns the sent plaintext references. impl<T> RangeMap<T>
pub(crate) fn sent(&self) -> &[Vector<U8>] { where
&self.sent T: Item,
} {
pub(crate) fn new(map: Vec<(usize, T)>) -> Self {
let mut pos = 0;
for (idx, item) in &map {
assert!(
*idx >= pos,
"items must be sorted by index and non-overlapping"
);
/// Returns the received plaintext references. pos = *idx + item.length();
pub(crate) fn recv(&self) -> &[Vector<U8>] {
&self.recv
}
/// Returns the transcript lengths.
pub(crate) fn len(&self) -> (usize, usize) {
let sent = self.sent.iter().map(|v| v.len()).sum();
let recv = self.recv.iter().map(|v| v.len()).sum();
(sent, recv)
}
/// Returns VM references for the given direction and index, otherwise
/// `None` if the index is out of bounds.
pub(crate) fn get(
&self,
direction: Direction,
idx: &RangeSet<usize>,
) -> Option<Vec<Vector<U8>>> {
if idx.is_empty() {
return Some(Vec::new());
} }
let refs = match direction { Self { map }
Direction::Sent => &self.sent, }
Direction::Received => &self.recv,
/// Returns the length of the map.
pub(crate) fn len(&self) -> usize {
self.map.iter().map(|(_, item)| item.length()).sum()
}
pub(crate) fn iter(&self) -> impl Iterator<Item = (Range<usize>, &T)> {
self.map
.iter()
.map(|(idx, item)| (*idx..*idx + item.length(), item))
}
pub(crate) fn get(&self, range: Range<usize>) -> Option<T::Slice<'_>> {
if range.start >= range.end {
return None;
}
// Find the item with the greatest start index <= range.start
let pos = match self.map.binary_search_by(|(idx, _)| idx.cmp(&range.start)) {
Ok(i) => i,
Err(0) => return None,
Err(i) => i - 1,
}; };
// Computes the transcript range for each reference. let (base, item) = &self.map[pos];
let mut start = 0;
let mut slice_iter = refs.iter().map(move |slice| {
let out = (slice, start..start + slice.len());
start += slice.len();
out
});
let mut slices = Vec::new(); item.slice(range.start - *base..range.end - *base)
let (mut slice, mut slice_range) = slice_iter.next()?; }
for range in idx.iter_ranges() {
loop {
if let Some(intersection) = slice_range.intersection(&range) {
let start = intersection.start - slice_range.start;
let end = intersection.end - slice_range.start;
slices.push(slice.get(start..end).expect("range should be in bounds"));
}
// Proceed to next range if the current slice extends beyond. Otherwise, proceed pub(crate) fn index(&self, idx: &RangeSet<usize>) -> Option<Self> {
// to the next slice. let mut map = Vec::new();
if range.end <= slice_range.end { for idx in idx.iter_ranges() {
break; let pos = match self.map.binary_search_by(|(base, _)| base.cmp(&idx.start)) {
} else { Ok(i) => i,
(slice, slice_range) = slice_iter.next()?; Err(0) => return None,
} Err(i) => i - 1,
};
let (base, item) = self.map.get(pos)?;
if idx.start < *base || idx.end > *base + item.length() {
return None;
} }
let start = idx.start - *base;
let end = start + idx.len();
map.push((
idx.start,
item.slice(start..end)
.expect("slice length is checked")
.into(),
));
} }
Some(slices) Some(Self { map })
} }
} }
/// Decodes the transcript. impl<T> FromIterator<(usize, T)> for RangeMap<T>
pub(crate) fn decode_transcript( where
vm: &mut dyn Vm<Binary>, T: Item,
sent: &RangeSet<usize>, {
recv: &RangeSet<usize>, fn from_iter<I: IntoIterator<Item = (usize, T)>>(items: I) -> Self {
refs: &TranscriptRefs, let mut pos = 0;
) -> Result<(), VmError> { let mut map = Vec::new();
let sent_refs = refs.get(Direction::Sent, sent).expect("index is in bounds"); for (idx, item) in items {
let recv_refs = refs assert!(
.get(Direction::Received, recv) idx >= pos,
.expect("index is in bounds"); "items must be sorted by index and non-overlapping"
);
for slice in sent_refs.into_iter().chain(recv_refs) { pos = idx + item.length();
// Drop the future, we don't need it. map.push((idx, item));
drop(vm.decode(slice)?); }
Self { map }
} }
Ok(())
} }
/// Verifies a partial transcript. pub(crate) trait Item: Sized {
pub(crate) fn verify_transcript( type Slice<'a>: Into<Self>
vm: &mut dyn Vm<Binary>, where
transcript: &PartialTranscript, Self: 'a;
refs: &TranscriptRefs,
) -> Result<(), InconsistentTranscript> {
let sent_refs = refs
.get(Direction::Sent, transcript.sent_authed())
.expect("index is in bounds");
let recv_refs = refs
.get(Direction::Received, transcript.received_authed())
.expect("index is in bounds");
let mut authenticated_data = Vec::new(); fn length(&self) -> usize;
for data in sent_refs.into_iter().chain(recv_refs) {
let plaintext = vm
.get(data)
.expect("reference is valid")
.expect("plaintext is decoded");
authenticated_data.extend_from_slice(&plaintext);
}
let mut purported_data = Vec::with_capacity(authenticated_data.len()); fn slice<'a>(&'a self, range: Range<usize>) -> Option<Self::Slice<'a>>;
for range in transcript.sent_authed().iter_ranges() {
purported_data.extend_from_slice(&transcript.sent_unsafe()[range]);
}
for range in transcript.received_authed().iter_ranges() {
purported_data.extend_from_slice(&transcript.received_unsafe()[range]);
}
if purported_data != authenticated_data {
return Err(InconsistentTranscript {});
}
Ok(())
} }
/// Error for [`verify_transcript`]. impl Item for Vector<U8> {
#[derive(Debug, thiserror::Error)] type Slice<'a> = Vector<U8>;
#[error("inconsistent transcript")]
pub(crate) struct InconsistentTranscript {} fn length(&self) -> usize {
self.len()
}
fn slice<'a>(&'a self, range: Range<usize>) -> Option<Self::Slice<'a>> {
self.get(range)
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::TranscriptRefs; use super::*;
use mpz_memory_core::{FromRaw, Slice, Vector, binary::U8};
use rangeset::RangeSet;
use std::ops::Range;
use tlsn_core::transcript::Direction;
// TRANSCRIPT_REFS: impl Item for Range<usize> {
// type Slice<'a> = Range<usize>;
// 48..96 -> 6 slots
// 112..176 -> 8 slots
// 240..288 -> 6 slots
// 352..392 -> 5 slots
// 440..480 -> 5 slots
const TRANSCRIPT_REFS: &[Range<usize>] = &[48..96, 112..176, 240..288, 352..392, 440..480];
const IDXS: &[Range<usize>] = &[0..4, 5..10, 14..16, 16..28]; fn length(&self) -> usize {
self.end - self.start
}
// 1. Take slots 0..4, 4 slots -> 48..80 (4) fn slice(&self, range: Range<usize>) -> Option<Self> {
// 2. Take slots 5..10, 5 slots -> 88..96 (1) + 112..144 (4) if range.end > self.end - self.start {
// 3. Take slots 14..16, 2 slots -> 240..256 (2) return None;
// 4. Take slots 16..28, 12 slots -> 256..288 (4) + 352..392 (5) + 440..464 (3) }
//
// 5. Merge slots 240..256 and 256..288 => 240..288 and get EXPECTED_REFS
const EXPECTED_REFS: &[Range<usize>] =
&[48..80, 88..96, 112..144, 240..288, 352..392, 440..464];
#[test] Some(range.start + self.start..range.end + self.start)
fn test_transcript_refs_get() {
let transcript_refs: Vec<Vector<U8>> = TRANSCRIPT_REFS
.iter()
.cloned()
.map(|range| Vector::from_raw(Slice::from_range_unchecked(range)))
.collect();
let transcript_refs = TranscriptRefs {
sent: transcript_refs.clone(),
recv: transcript_refs,
};
let vm_refs = transcript_refs
.get(Direction::Sent, &RangeSet::from(IDXS))
.unwrap();
let expected_refs: Vec<Vector<U8>> = EXPECTED_REFS
.iter()
.cloned()
.map(|range| Vector::from_raw(Slice::from_range_unchecked(range)))
.collect();
assert_eq!(
vm_refs.len(),
expected_refs.len(),
"Length of actual and expected refs are not equal"
);
for (&expected, actual) in expected_refs.iter().zip(vm_refs) {
assert_eq!(expected, actual);
} }
} }
#[test]
fn test_range_map() {
let map = RangeMap::from_iter([(0, 10..14), (10, 20..24), (20, 30..32)]);
assert_eq!(map.get(0..4), Some(10..14));
assert_eq!(map.get(10..14), Some(20..24));
assert_eq!(map.get(20..22), Some(30..32));
assert_eq!(map.get(0..2), Some(10..12));
assert_eq!(map.get(11..13), Some(21..23));
assert_eq!(map.get(0..10), None);
assert_eq!(map.get(10..20), None);
assert_eq!(map.get(20..30), None);
}
#[test]
fn test_range_map_index() {
let map = RangeMap::from_iter([(0, 10..14), (10, 20..24), (20, 30..32)]);
let idx = RangeSet::from([0..4, 10..14, 20..22]);
assert_eq!(map.index(&idx), Some(map.clone()));
let idx = RangeSet::from(25..30);
assert_eq!(map.index(&idx), None);
let idx = RangeSet::from(15..20);
assert_eq!(map.index(&idx), None);
let idx = RangeSet::from([1..3, 11..12, 13..14, 21..22]);
assert_eq!(
map.index(&idx),
Some(RangeMap::from_iter([
(1, 11..13),
(11, 21..22),
(13, 23..24),
(21, 31..32)
]))
);
}
} }
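A standalone sketch (not part of this diff) of the lookup rule behind `RangeMap::get`, mirroring `test_range_map` above: entries are `(start, item)` pairs sorted by start; a query resolves against the single entry that covers its start and is then sliced relative to that entry's base, so a query spanning two entries returns `None`.

use std::ops::Range;

// Entries map a transcript offset to a backing range, like the `Item` impl
// for `Range<usize>` used in the tests above.
fn get(map: &[(usize, Range<usize>)], query: Range<usize>) -> Option<Range<usize>> {
    if query.start >= query.end {
        return None;
    }
    // Greatest entry whose start is <= query.start.
    let pos = match map.binary_search_by(|(start, _)| start.cmp(&query.start)) {
        Ok(i) => i,
        Err(0) => return None,
        Err(i) => i - 1,
    };
    let (base, item) = &map[pos];
    let (start, end) = (query.start - base, query.end - base);
    if end > item.end - item.start {
        // Query extends past this entry: no single backing slice covers it.
        return None;
    }
    Some(item.start + start..item.start + end)
}

// get(&[(0, 10..14), (10, 20..24), (20, 30..32)], 11..13) == Some(21..23)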


@@ -13,7 +13,7 @@ use rangeset::RangeSet;
 use serde::{Deserialize, Serialize};
 use serio::{SinkExt, stream::IoStreamExt};
 use tlsn_core::{
-    hash::HashAlgorithm,
+    hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256},
     transcript::{
         Direction,
         encoding::{
@@ -23,7 +23,7 @@ use tlsn_core::{
     },
 };
-use crate::commit::transcript::TranscriptRefs;
+use crate::commit::transcript::{Item, RangeMap, ReferenceMap};
 /// Bytes of encoding, per byte.
 const ENCODING_SIZE: usize = 128;
@@ -34,145 +34,130 @@ struct Encodings {
recv: Vec<u8>, recv: Vec<u8>,
} }
/// Transfers the encodings using the provided seed and keys. /// Transfers encodings for the provided plaintext ranges.
/// pub(crate) async fn transfer<K: KeyStore>(
/// The keys must be consistent with the global delta used in the encodings.
pub(crate) async fn transfer<'a>(
ctx: &mut Context, ctx: &mut Context,
refs: &TranscriptRefs, store: &K,
delta: &Delta, sent: &ReferenceMap,
f: impl Fn(Vector<U8>) -> &'a [Key], recv: &ReferenceMap,
) -> Result<EncodingCommitment, EncodingError> { ) -> Result<EncodingCommitment, EncodingError> {
let secret = EncoderSecret::new(rand::rng().random(), delta.as_block().to_bytes()); let secret = EncoderSecret::new(rand::rng().random(), store.delta().as_block().to_bytes());
let encoder = new_encoder(&secret); let encoder = new_encoder(&secret);
let sent_keys: Vec<u8> = refs // Collects the encodings for the provided plaintext ranges.
.sent() fn collect_encodings(
.iter() encoder: &impl Encoder,
.copied() store: &impl KeyStore,
.flat_map(&f) direction: Direction,
.flat_map(|key| key.as_block().as_bytes()) map: &ReferenceMap,
.copied() ) -> Vec<u8> {
.collect(); let mut encodings = Vec::with_capacity(map.len() * ENCODING_SIZE);
let recv_keys: Vec<u8> = refs for (range, chunk) in map.iter() {
.recv() let start = encodings.len();
.iter() encoder.encode_range(direction, range, &mut encodings);
.copied() let keys = store
.flat_map(&f) .get_keys(*chunk)
.flat_map(|key| key.as_block().as_bytes()) .expect("keys are present for provided plaintext ranges");
.copied() encodings[start..]
.collect(); .iter_mut()
.zip(keys.iter().flat_map(|key| key.as_block().as_bytes()))
.for_each(|(encoding, key)| {
*encoding ^= *key;
});
}
encodings
}
assert_eq!(sent_keys.len() % ENCODING_SIZE, 0); let encodings = Encodings {
assert_eq!(recv_keys.len() % ENCODING_SIZE, 0); sent: collect_encodings(&encoder, store, Direction::Sent, sent),
recv: collect_encodings(&encoder, store, Direction::Received, recv),
};
let mut sent_encoding = Vec::with_capacity(sent_keys.len()); let frame_limit = ctx
let mut recv_encoding = Vec::with_capacity(recv_keys.len()); .io()
.limit()
encoder.encode_range( .saturating_add(encodings.sent.len() + encodings.recv.len());
Direction::Sent, ctx.io_mut().with_limit(frame_limit).send(encodings).await?;
0..sent_keys.len() / ENCODING_SIZE,
&mut sent_encoding,
);
encoder.encode_range(
Direction::Received,
0..recv_keys.len() / ENCODING_SIZE,
&mut recv_encoding,
);
sent_encoding
.iter_mut()
.zip(sent_keys)
.for_each(|(enc, key)| *enc ^= key);
recv_encoding
.iter_mut()
.zip(recv_keys)
.for_each(|(enc, key)| *enc ^= key);
// Set frame limit and add some extra bytes cushion room.
let (sent, recv) = refs.len();
let frame_limit = ENCODING_SIZE * (sent + recv) + ctx.io().limit();
ctx.io_mut()
.with_limit(frame_limit)
.send(Encodings {
sent: sent_encoding,
recv: recv_encoding,
})
.await?;
let root = ctx.io_mut().expect_next().await?; let root = ctx.io_mut().expect_next().await?;
ctx.io_mut().send(secret.clone()).await?; ctx.io_mut().send(secret.clone()).await?;
Ok(EncodingCommitment { Ok(EncodingCommitment { root, secret })
root,
secret: secret.clone(),
})
} }
/// Receives the encodings using the provided MACs. /// Receives and commits to the encodings for the provided plaintext ranges.
/// pub(crate) async fn receive<M: MacStore>(
/// The MACs must be consistent with the global delta used in the encodings.
pub(crate) async fn receive<'a>(
ctx: &mut Context, ctx: &mut Context,
hasher: &(dyn HashAlgorithm + Send + Sync), store: &M,
refs: &TranscriptRefs, hash_alg: HashAlgId,
f: impl Fn(Vector<U8>) -> &'a [Mac], sent: &ReferenceMap,
recv: &ReferenceMap,
idxs: impl IntoIterator<Item = &(Direction, RangeSet<usize>)>, idxs: impl IntoIterator<Item = &(Direction, RangeSet<usize>)>,
) -> Result<(EncodingCommitment, EncodingTree), EncodingError> { ) -> Result<(EncodingCommitment, EncodingTree), EncodingError> {
// Set frame limit and add some extra bytes cushion room. let hasher: &(dyn HashAlgorithm + Send + Sync) = match hash_alg {
let (sent, recv) = refs.len(); HashAlgId::SHA256 => &Sha256::default(),
let frame_limit = ENCODING_SIZE * (sent + recv) + ctx.io().limit(); HashAlgId::KECCAK256 => &Keccak256::default(),
HashAlgId::BLAKE3 => &Blake3::default(),
alg => {
return Err(ErrorRepr::UnsupportedHashAlgorithm(alg).into());
}
};
let Encodings { mut sent, mut recv } = let (sent_len, recv_len) = (sent.len(), recv.len());
ctx.io_mut().with_limit(frame_limit).expect_next().await?; let frame_limit = ctx
.io()
.limit()
.saturating_add(ENCODING_SIZE * (sent_len + recv_len));
let encodings: Encodings = ctx.io_mut().with_limit(frame_limit).expect_next().await?;
let sent_macs: Vec<u8> = refs if encodings.sent.len() != sent_len * ENCODING_SIZE {
.sent()
.iter()
.copied()
.flat_map(&f)
.flat_map(|mac| mac.as_bytes())
.copied()
.collect();
let recv_macs: Vec<u8> = refs
.recv()
.iter()
.copied()
.flat_map(&f)
.flat_map(|mac| mac.as_bytes())
.copied()
.collect();
assert_eq!(sent_macs.len() % ENCODING_SIZE, 0);
assert_eq!(recv_macs.len() % ENCODING_SIZE, 0);
if sent.len() != sent_macs.len() {
return Err(ErrorRepr::IncorrectMacCount { return Err(ErrorRepr::IncorrectMacCount {
direction: Direction::Sent, direction: Direction::Sent,
expected: sent_macs.len(), expected: sent_len,
got: sent.len(), got: encodings.sent.len() / ENCODING_SIZE,
} }
.into()); .into());
} }
if recv.len() != recv_macs.len() { if encodings.recv.len() != recv_len * ENCODING_SIZE {
return Err(ErrorRepr::IncorrectMacCount { return Err(ErrorRepr::IncorrectMacCount {
direction: Direction::Received, direction: Direction::Received,
expected: recv_macs.len(), expected: recv_len,
got: recv.len(), got: encodings.recv.len() / ENCODING_SIZE,
} }
.into()); .into());
} }
sent.iter_mut() // Collects a map of plaintext ranges to their encodings.
.zip(sent_macs) fn collect_map(
.for_each(|(enc, mac)| *enc ^= mac); store: &impl MacStore,
recv.iter_mut() mut encodings: Vec<u8>,
.zip(recv_macs) map: &ReferenceMap,
.for_each(|(enc, mac)| *enc ^= mac); ) -> RangeMap<EncodingSlice> {
let mut encoding_map = Vec::new();
let mut pos = 0;
for (range, chunk) in map.iter() {
let macs = store
.get_macs(*chunk)
.expect("MACs are present for provided plaintext ranges");
let encoding = &mut encodings[pos..pos + range.len() * ENCODING_SIZE];
encoding
.iter_mut()
.zip(macs.iter().flat_map(|mac| mac.as_bytes()))
.for_each(|(encoding, mac)| {
*encoding ^= *mac;
});
let provider = Provider { sent, recv }; encoding_map.push((range.start, EncodingSlice::from(&(*encoding))));
pos += range.len() * ENCODING_SIZE;
}
RangeMap::new(encoding_map)
}
let provider = Provider {
sent: collect_map(store, encodings.sent, sent),
recv: collect_map(store, encodings.recv, recv),
};
let tree = EncodingTree::new(hasher, idxs, &provider)?; let tree = EncodingTree::new(hasher, idxs, &provider)?;
let root = tree.root(); let root = tree.root();
@@ -185,10 +170,36 @@ pub(crate) async fn receive<'a>(
Ok((commitment, tree)) Ok((commitment, tree))
} }
pub(crate) trait KeyStore {
fn delta(&self) -> &Delta;
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]>;
}
impl KeyStore for crate::verifier::Zk {
fn delta(&self) -> &Delta {
crate::verifier::Zk::delta(self)
}
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]> {
self.get_keys(data).ok()
}
}
pub(crate) trait MacStore {
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]>;
}
impl MacStore for crate::prover::Zk {
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]> {
self.get_macs(data).ok()
}
}
 #[derive(Debug)]
 struct Provider {
-    sent: Vec<u8>,
-    recv: Vec<u8>,
+    sent: RangeMap<EncodingSlice>,
+    recv: RangeMap<EncodingSlice>,
 }
 impl EncodingProvider for Provider {
@@ -203,19 +214,39 @@ impl EncodingProvider for Provider {
             Direction::Received => &self.recv,
         };
-        let start = range.start * ENCODING_SIZE;
-        let end = range.end * ENCODING_SIZE;
-        if end > encodings.len() {
-            return Err(EncodingProviderError);
-        }
-        dest.extend_from_slice(&encodings[start..end]);
+        let encoding = encodings.get(range).ok_or(EncodingProviderError)?;
+        dest.extend_from_slice(encoding);
         Ok(())
     }
 }
#[derive(Debug)]
struct EncodingSlice(Vec<u8>);
impl From<&[u8]> for EncodingSlice {
fn from(value: &[u8]) -> Self {
Self(value.to_vec())
}
}
impl Item for EncodingSlice {
type Slice<'a>
= &'a [u8]
where
Self: 'a;
fn length(&self) -> usize {
self.0.len() / ENCODING_SIZE
}
fn slice<'a>(&'a self, range: Range<usize>) -> Option<Self::Slice<'a>> {
self.0
.get(range.start * ENCODING_SIZE..range.end * ENCODING_SIZE)
}
}
 /// Encoding protocol error.
 #[derive(Debug, thiserror::Error)]
 #[error(transparent)]
@@ -234,6 +265,8 @@ enum ErrorRepr {
     },
     #[error("encoding tree error: {0}")]
     EncodingTree(EncodingTreeError),
+    #[error("unsupported hash algorithm: {0}")]
+    UnsupportedHashAlgorithm(HashAlgId),
 }
 impl From<std::io::Error> for EncodingError {
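A short sketch (not part of this diff) of the unit conversion that `EncodingSlice` encapsulates above: the `RangeMap` is indexed in plaintext bytes, while each plaintext byte owns `ENCODING_SIZE` bytes of encoding, so slicing scales both endpoints.

const ENCODING_SIZE: usize = 128; // bytes of encoding per plaintext byte

// Slice an encoding buffer by a plaintext-byte range; returns None when the
// scaled range falls outside the buffer.
fn slice_encoding(encoding: &[u8], range: std::ops::Range<usize>) -> Option<&[u8]> {
    encoding.get(range.start * ENCODING_SIZE..range.end * ENCODING_SIZE)
}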


@@ -9,7 +9,6 @@ pub mod config;
 pub(crate) mod context;
 pub(crate) mod encoding;
 pub(crate) mod ghash;
-pub(crate) mod msg;
 pub(crate) mod mux;
 pub mod prover;
 pub(crate) mod tag;


@@ -1,15 +0,0 @@
//! Message types.
use serde::{Deserialize, Serialize};
use tlsn_core::connection::{HandshakeData, ServerName};
/// Message sent from Prover to Verifier to prove the server identity.
#[derive(Debug, Serialize, Deserialize)]
#[allow(dead_code)]
pub(crate) struct ServerIdentityProof {
/// Server name.
pub name: ServerName,
/// Server identity data.
pub data: HandshakeData,
}


@@ -3,6 +3,7 @@
 mod config;
 mod error;
 mod future;
+mod prove;
 pub mod state;
 pub use config::{ProverConfig, ProverConfigBuilder, TlsConfig, TlsConfigBuilder};
@@ -18,19 +19,7 @@ use mpz_vm_core::prelude::*;
use mpz_zk::ProverConfig as ZkProverConfig; use mpz_zk::ProverConfig as ZkProverConfig;
use webpki::anchor_from_trusted_cert; use webpki::anchor_from_trusted_cert;
use crate::{ use crate::{Role, context::build_mt_context, mux::attach_mux, tag::verify_tags};
Role,
commit::{
commit_records,
hash::prove_hash,
transcript::{TranscriptRefs, decode_transcript},
},
context::build_mt_context,
encoding,
mux::attach_mux,
tag::verify_tags,
zk_aes_ctr::ZkAesCtr,
};
use futures::{AsyncRead, AsyncWrite, TryFutureExt}; use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys}; use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
@@ -39,12 +28,9 @@ use serio::SinkExt;
use std::sync::Arc; use std::sync::Arc;
use tls_client::{ClientConnection, ServerName as TlsServerName}; use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client}; use tls_client_async::{TlsConnection, bind_client};
use tls_core::msgs::enums::ContentType;
use tlsn_core::{ use tlsn_core::{
ProvePayload, connection::ServerName,
connection::{HandshakeData, ServerName}, transcript::{TlsTranscript, Transcript},
hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256},
transcript::{TlsTranscript, Transcript, TranscriptCommitment, TranscriptSecret},
}; };
use tlsn_deap::Deap; use tlsn_deap::Deap;
use tokio::sync::Mutex; use tokio::sync::Mutex;
@@ -115,22 +101,6 @@ impl Prover<state::Initialized> {
let mut keys = mpc_tls.alloc()?; let mut keys = mpc_tls.alloc()?;
let vm_lock = vm.try_lock().expect("VM is not locked"); let vm_lock = vm.try_lock().expect("VM is not locked");
translate_keys(&mut keys, &vm_lock)?; translate_keys(&mut keys, &vm_lock)?;
// Allocate for committing to plaintext.
let mut zk_aes_ctr_sent = ZkAesCtr::new(Role::Prover);
zk_aes_ctr_sent.set_key(keys.client_write_key, keys.client_write_iv);
zk_aes_ctr_sent.alloc(
&mut *vm_lock.zk(),
self.config.protocol_config().max_sent_data(),
)?;
let mut zk_aes_ctr_recv = ZkAesCtr::new(Role::Prover);
zk_aes_ctr_recv.set_key(keys.server_write_key, keys.server_write_iv);
zk_aes_ctr_recv.alloc(
&mut *vm_lock.zk(),
self.config.protocol_config().max_recv_data(),
)?;
drop(vm_lock); drop(vm_lock);
debug!("setting up mpc-tls"); debug!("setting up mpc-tls");
@@ -146,8 +116,6 @@ impl Prover<state::Initialized> {
mux_ctrl, mux_ctrl,
mux_fut, mux_fut,
mpc_tls, mpc_tls,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys, keys,
vm, vm,
}, },
@@ -173,8 +141,6 @@ impl Prover<state::Setup> {
mux_ctrl, mux_ctrl,
mut mux_fut, mut mux_fut,
mpc_tls, mpc_tls,
mut zk_aes_ctr_sent,
mut zk_aes_ctr_recv,
keys, keys,
vm, vm,
.. ..
@@ -281,35 +247,6 @@ impl Prover<state::Setup> {
) )
.map_err(ProverError::zk)?; .map_err(ProverError::zk)?;
// Prove sent and received plaintext. Prover drops the proof
// output, as they trust themselves.
let (sent_refs, _) = commit_records(
&mut vm,
&mut zk_aes_ctr_sent,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(ProverError::zk)?;
let (recv_refs, _) = commit_records(
&mut vm,
&mut zk_aes_ctr_recv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(|e| {
if e.is_insufficient() {
ProverError::zk(format!("{e}. Attempted to prove more received data than was configured, increase `max_recv_data` in the config."))
} else {
ProverError::zk(e)
}
})?;
mux_fut mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk)) .poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
.await?; .await?;
@@ -317,7 +254,6 @@ impl Prover<state::Setup> {
let transcript = tls_transcript let transcript = tls_transcript
.to_transcript() .to_transcript()
.expect("transcript is complete"); .expect("transcript is complete");
let transcript_refs = TranscriptRefs::new(sent_refs, recv_refs);
Ok(Prover { Ok(Prover {
config: self.config, config: self.config,
@@ -327,9 +263,9 @@ impl Prover<state::Setup> {
mux_fut, mux_fut,
ctx, ctx,
vm, vm,
keys,
tls_transcript, tls_transcript,
transcript, transcript,
transcript_refs,
}, },
}) })
} }
@@ -368,117 +304,24 @@ impl Prover<state::Committed> {
mux_fut, mux_fut,
ctx, ctx,
vm, vm,
keys,
tls_transcript, tls_transcript,
transcript, transcript,
transcript_refs,
.. ..
} = &mut self.state; } = &mut self.state;
let mut output = ProverOutput { let output = mux_fut
transcript_commitments: Vec::new(), .poll_with(prove::prove(
transcript_secrets: Vec::new(), ctx,
}; vm,
keys,
let partial_transcript = if let Some((sent, recv)) = config.reveal() { self.config.server_name(),
decode_transcript(vm, sent, recv, transcript_refs).map_err(ProverError::zk)?; transcript,
tls_transcript,
Some(transcript.to_partial(sent.clone(), recv.clone())) config,
} else { ))
None
};
let payload = ProvePayload {
handshake: config.server_identity().then(|| {
(
self.config.server_name().clone(),
HandshakeData {
certs: tls_transcript
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: tls_transcript
.server_signature()
.expect("server signature is present")
.clone(),
binding: tls_transcript.certificate_binding().clone(),
},
)
}),
transcript: partial_transcript,
transcript_commit: config.transcript_commit().map(|config| config.to_request()),
};
// Send payload.
mux_fut
.poll_with(ctx.io_mut().send(payload).map_err(ProverError::from))
.await?; .await?;
let mut hash_commitments = None;
if let Some(commit_config) = config.transcript_commit() {
if commit_config.has_encoding() {
let hasher: &(dyn HashAlgorithm + Send + Sync) =
match *commit_config.encoding_hash_alg() {
HashAlgId::SHA256 => &Sha256::default(),
HashAlgId::KECCAK256 => &Keccak256::default(),
HashAlgId::BLAKE3 => &Blake3::default(),
alg => {
return Err(ProverError::config(format!(
"unsupported hash algorithm for encoding commitment: {alg}"
)));
}
};
let (commitment, tree) = mux_fut
.poll_with(
encoding::receive(
ctx,
hasher,
transcript_refs,
|plaintext| vm.get_macs(plaintext).expect("reference is valid"),
commit_config.iter_encoding(),
)
.map_err(ProverError::commit),
)
.await?;
output
.transcript_commitments
.push(TranscriptCommitment::Encoding(commitment));
output
.transcript_secrets
.push(TranscriptSecret::Encoding(tree));
}
if commit_config.has_hash() {
hash_commitments = Some(
prove_hash(
vm,
transcript_refs,
commit_config
.iter_hash()
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg)),
)
.map_err(ProverError::commit)?,
);
}
}
mux_fut
.poll_with(vm.execute_all(ctx).map_err(ProverError::zk))
.await?;
if let Some((hash_fut, hash_secrets)) = hash_commitments {
let hash_commitments = hash_fut.try_recv().map_err(ProverError::commit)?;
for (commitment, secret) in hash_commitments.into_iter().zip(hash_secrets) {
output
.transcript_commitments
.push(TranscriptCommitment::Hash(commitment));
output
.transcript_secrets
.push(TranscriptSecret::Hash(secret));
}
}
Ok(output) Ok(output)
} }


@@ -2,7 +2,7 @@ use std::{error::Error, fmt};
 use mpc_tls::MpcTlsError;
-use crate::{encoding::EncodingError, zk_aes_ctr::ZkAesCtrError};
+use crate::encoding::EncodingError;
 /// Error for [`Prover`](crate::Prover).
 #[derive(Debug, thiserror::Error)]
@@ -110,12 +110,6 @@ impl From<MpcTlsError> for ProverError {
     }
 }
-impl From<ZkAesCtrError> for ProverError {
-    fn from(e: ZkAesCtrError) -> Self {
-        Self::new(ErrorKind::Zk, e)
-    }
-}
 impl From<EncodingError> for ProverError {
     fn from(e: EncodingError) -> Self {
         Self::new(ErrorKind::Commit, e)


@@ -0,0 +1,198 @@
use mpc_tls::SessionKeys;
use mpz_common::Context;
use mpz_memory_core::binary::Binary;
use mpz_vm_core::Vm;
use rangeset::{RangeSet, UnionMut};
use serio::SinkExt;
use tlsn_core::{
ProveConfig, ProveRequest, ProverOutput,
connection::{HandshakeData, ServerName},
transcript::{
ContentType, Direction, TlsTranscript, Transcript, TranscriptCommitment, TranscriptSecret,
},
};
use crate::{
commit::{auth::prove_plaintext, hash::prove_hash, transcript::TranscriptRefs},
encoding::{self, MacStore},
prover::ProverError,
zk_aes_ctr::ZkAesCtr,
};
pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
ctx: &mut Context,
vm: &mut T,
keys: &SessionKeys,
server_name: &ServerName,
transcript: &Transcript,
tls_transcript: &TlsTranscript,
config: &ProveConfig,
) -> Result<ProverOutput, ProverError> {
let mut output = ProverOutput {
transcript_commitments: Vec::default(),
transcript_secrets: Vec::default(),
};
let request = ProveRequest {
handshake: config.server_identity().then(|| {
(
server_name.clone(),
HandshakeData {
certs: tls_transcript
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: tls_transcript
.server_signature()
.expect("server signature is present")
.clone(),
binding: tls_transcript.certificate_binding().clone(),
},
)
}),
transcript: config
.reveal()
.map(|(sent, recv)| transcript.to_partial(sent.clone(), recv.clone())),
transcript_commit: config.transcript_commit().map(|config| config.to_request()),
};
ctx.io_mut()
.send(request)
.await
.map_err(ProverError::from)?;
let mut auth_sent_ranges = RangeSet::default();
let mut auth_recv_ranges = RangeSet::default();
let (reveal_sent, reveal_recv) = config.reveal().cloned().unwrap_or_default();
auth_sent_ranges.union_mut(&reveal_sent);
auth_recv_ranges.union_mut(&reveal_recv);
if let Some(commit_config) = config.transcript_commit() {
commit_config
.iter_hash()
.for_each(|((direction, idx), _)| match direction {
Direction::Sent => auth_sent_ranges.union_mut(idx),
Direction::Received => auth_recv_ranges.union_mut(idx),
});
commit_config
.iter_encoding()
.for_each(|(direction, idx)| match direction {
Direction::Sent => auth_sent_ranges.union_mut(idx),
Direction::Received => auth_recv_ranges.union_mut(idx),
});
}
let mut zk_aes_sent = ZkAesCtr::new(
keys.client_write_key,
keys.client_write_iv,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
);
let mut zk_aes_recv = ZkAesCtr::new(
keys.server_write_key,
keys.server_write_iv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
);
let sent_refs = prove_plaintext(
vm,
&mut zk_aes_sent,
transcript.sent(),
&auth_sent_ranges,
&reveal_sent,
)
.map_err(ProverError::commit)?;
let recv_refs = prove_plaintext(
vm,
&mut zk_aes_recv,
transcript.received(),
&auth_recv_ranges,
&reveal_recv,
)
.map_err(ProverError::commit)?;
let transcript_refs = TranscriptRefs {
sent: sent_refs,
recv: recv_refs,
};
let hash_commitments = if let Some(commit_config) = config.transcript_commit()
&& commit_config.has_hash()
{
Some(
prove_hash(
vm,
&transcript_refs,
commit_config
.iter_hash()
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg)),
)
.map_err(ProverError::commit)?,
)
} else {
None
};
vm.execute_all(ctx).await.map_err(ProverError::zk)?;
if let Some(commit_config) = config.transcript_commit()
&& commit_config.has_encoding()
{
let mut sent_ranges = RangeSet::default();
let mut recv_ranges = RangeSet::default();
for (dir, idx) in commit_config.iter_encoding() {
match dir {
Direction::Sent => sent_ranges.union_mut(idx),
Direction::Received => recv_ranges.union_mut(idx),
}
}
let sent_map = transcript_refs
.sent
.index(&sent_ranges)
.expect("indices are valid");
let recv_map = transcript_refs
.recv
.index(&recv_ranges)
.expect("indices are valid");
let (commitment, tree) = encoding::receive(
ctx,
vm,
*commit_config.encoding_hash_alg(),
&sent_map,
&recv_map,
commit_config.iter_encoding(),
)
.await?;
output
.transcript_commitments
.push(TranscriptCommitment::Encoding(commitment));
output
.transcript_secrets
.push(TranscriptSecret::Encoding(tree));
}
if let Some((hash_fut, hash_secrets)) = hash_commitments {
let hash_commitments = hash_fut.try_recv().map_err(ProverError::commit)?;
for (commitment, secret) in hash_commitments.into_iter().zip(hash_secrets) {
output
.transcript_commitments
.push(TranscriptCommitment::Hash(commitment));
output
.transcript_secrets
.push(TranscriptSecret::Hash(secret));
}
}
Ok(output)
}
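The bookkeeping at the top of `prove` is what makes the authentication partial: only data that will be revealed or committed (hash or encoding) is proven against the ciphertext, and the revealed subset is additionally made public to the verifier. A minimal sketch (not part of this diff, `rangeset` 0.2 assumed) of that union, per direction:

use rangeset::{RangeSet, Subset, UnionMut};

// Union the reveal ranges with every commitment range for one direction.
fn auth_ranges(reveal: &RangeSet<usize>, commits: &[RangeSet<usize>]) -> RangeSet<usize> {
    let mut auth = RangeSet::default();
    auth.union_mut(reveal);
    for idx in commits {
        auth.union_mut(idx);
    }
    // The revealed (public) ranges are always a subset of what is authenticated.
    debug_assert!(reveal.is_subset(&auth));
    auth
}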


@@ -9,10 +9,8 @@ use tlsn_deap::Deap;
 use tokio::sync::Mutex;
 use crate::{
-    commit::transcript::TranscriptRefs,
     mux::{MuxControl, MuxFuture},
     prover::{Mpc, Zk},
-    zk_aes_ctr::ZkAesCtr,
 };
 /// Entry state
@@ -25,8 +23,6 @@ pub struct Setup {
     pub(crate) mux_ctrl: MuxControl,
     pub(crate) mux_fut: MuxFuture,
     pub(crate) mpc_tls: MpcTlsLeader,
-    pub(crate) zk_aes_ctr_sent: ZkAesCtr,
-    pub(crate) zk_aes_ctr_recv: ZkAesCtr,
     pub(crate) keys: SessionKeys,
     pub(crate) vm: Arc<Mutex<Deap<Mpc, Zk>>>,
 }
@@ -39,9 +35,9 @@ pub struct Committed {
     pub(crate) mux_fut: MuxFuture,
     pub(crate) ctx: Context,
     pub(crate) vm: Zk,
+    pub(crate) keys: SessionKeys,
     pub(crate) tls_transcript: TlsTranscript,
     pub(crate) transcript: Transcript,
-    pub(crate) transcript_refs: TranscriptRefs,
 }
 opaque_debug::implement!(Committed);


@@ -3,6 +3,7 @@
 pub(crate) mod config;
 mod error;
 pub mod state;
+mod verify;
use std::sync::Arc; use std::sync::Arc;
@@ -14,18 +15,7 @@ pub use tlsn_core::{
}; };
use crate::{ use crate::{
Role, Role, config::ProtocolConfig, context::build_mt_context, mux::attach_mux, tag::verify_tags,
commit::{
commit_records,
hash::verify_hash,
transcript::{TranscriptRefs, decode_transcript, verify_transcript},
},
config::ProtocolConfig,
context::build_mt_context,
encoding,
mux::attach_mux,
tag::verify_tags,
zk_aes_ctr::ZkAesCtr,
}; };
use futures::{AsyncRead, AsyncWrite, TryFutureExt}; use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{MpcTlsFollower, SessionKeys}; use mpc_tls::{MpcTlsFollower, SessionKeys};
@@ -35,11 +25,9 @@ use mpz_garble_core::Delta;
use mpz_vm_core::prelude::*; use mpz_vm_core::prelude::*;
use mpz_zk::VerifierConfig as ZkVerifierConfig; use mpz_zk::VerifierConfig as ZkVerifierConfig;
use serio::stream::IoStreamExt; use serio::stream::IoStreamExt;
use tls_core::msgs::enums::ContentType;
use tlsn_core::{ use tlsn_core::{
ProvePayload,
connection::{ConnectionInfo, ServerName}, connection::{ConnectionInfo, ServerName},
transcript::{TlsTranscript, TranscriptCommitment}, transcript::TlsTranscript,
}; };
use tlsn_deap::Deap; use tlsn_deap::Deap;
use tokio::sync::Mutex; use tokio::sync::Mutex;
@@ -114,23 +102,12 @@ impl Verifier<state::Initialized> {
}) })
.await?; .await?;
let delta = Delta::random(&mut rand::rng()); let (vm, mut mpc_tls) = build_mpc_tls(&self.config, &protocol_config, ctx);
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, &protocol_config, delta, ctx);
// Allocate resources for MPC-TLS in the VM. // Allocate resources for MPC-TLS in the VM.
let mut keys = mpc_tls.alloc()?; let mut keys = mpc_tls.alloc()?;
let vm_lock = vm.try_lock().expect("VM is not locked"); let vm_lock = vm.try_lock().expect("VM is not locked");
translate_keys(&mut keys, &vm_lock)?; translate_keys(&mut keys, &vm_lock)?;
// Allocate for committing to plaintext.
let mut zk_aes_ctr_sent = ZkAesCtr::new(Role::Verifier);
zk_aes_ctr_sent.set_key(keys.client_write_key, keys.client_write_iv);
zk_aes_ctr_sent.alloc(&mut *vm_lock.zk(), protocol_config.max_sent_data())?;
let mut zk_aes_ctr_recv = ZkAesCtr::new(Role::Verifier);
zk_aes_ctr_recv.set_key(keys.server_write_key, keys.server_write_iv);
zk_aes_ctr_recv.alloc(&mut *vm_lock.zk(), protocol_config.max_recv_data())?;
drop(vm_lock); drop(vm_lock);
debug!("setting up mpc-tls"); debug!("setting up mpc-tls");
@@ -145,10 +122,7 @@ impl Verifier<state::Initialized> {
state: state::Setup { state: state::Setup {
mux_ctrl, mux_ctrl,
mux_fut, mux_fut,
delta,
mpc_tls, mpc_tls,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys, keys,
vm, vm,
}, },
@@ -186,10 +160,7 @@ impl Verifier<state::Setup> {
let state::Setup { let state::Setup {
mux_ctrl, mux_ctrl,
mut mux_fut, mut mux_fut,
delta,
mpc_tls, mpc_tls,
mut zk_aes_ctr_sent,
mut zk_aes_ctr_recv,
vm, vm,
keys, keys,
} = self.state; } = self.state;
@@ -230,27 +201,6 @@ impl Verifier<state::Setup> {
) )
.map_err(VerifierError::zk)?; .map_err(VerifierError::zk)?;
// Prepare for the prover to prove received plaintext.
let (sent_refs, sent_proof) = commit_records(
&mut vm,
&mut zk_aes_ctr_sent,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(VerifierError::zk)?;
let (recv_refs, recv_proof) = commit_records(
&mut vm,
&mut zk_aes_ctr_recv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(VerifierError::zk)?;
mux_fut mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(VerifierError::zk)) .poll_with(vm.execute_all(&mut ctx).map_err(VerifierError::zk))
.await?; .await?;
@@ -260,23 +210,16 @@ impl Verifier<state::Setup> {
// authenticated from the verifier's perspective. // authenticated from the verifier's perspective.
tag_proof.verify().map_err(VerifierError::zk)?; tag_proof.verify().map_err(VerifierError::zk)?;
// Verify the plaintext proofs.
sent_proof.verify().map_err(VerifierError::zk)?;
recv_proof.verify().map_err(VerifierError::zk)?;
let transcript_refs = TranscriptRefs::new(sent_refs, recv_refs);
Ok(Verifier { Ok(Verifier {
config: self.config, config: self.config,
span: self.span, span: self.span,
state: state::Committed { state: state::Committed {
mux_ctrl, mux_ctrl,
mux_fut, mux_fut,
delta,
ctx, ctx,
vm, vm,
keys,
tls_transcript, tls_transcript,
transcript_refs,
}, },
}) })
} }
@@ -301,130 +244,34 @@ impl Verifier<state::Committed> {
let state::Committed { let state::Committed {
mux_fut, mux_fut,
ctx, ctx,
delta,
vm, vm,
keys,
tls_transcript, tls_transcript,
transcript_refs,
.. ..
} = &mut self.state; } = &mut self.state;
let ProvePayload { let cert_verifier = if let Some(root_store) = self.config.root_store() {
handshake,
transcript,
transcript_commit,
} = mux_fut
.poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from))
.await?;
let verifier = if let Some(root_store) = self.config.root_store() {
ServerCertVerifier::new(root_store).map_err(VerifierError::config)? ServerCertVerifier::new(root_store).map_err(VerifierError::config)?
} else { } else {
ServerCertVerifier::mozilla() ServerCertVerifier::mozilla()
}; };
let server_name = if let Some((name, cert_data)) = handshake { let request = mux_fut
cert_data .poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from))
.verify(
&verifier,
tls_transcript.time(),
tls_transcript.server_ephemeral_key(),
&name,
)
.map_err(VerifierError::verify)?;
Some(name)
} else {
None
};
if let Some(partial_transcript) = &transcript {
let sent_len = tls_transcript
.sent()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
let recv_len = tls_transcript
.recv()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
// Check ranges.
if partial_transcript.len_sent() != sent_len
|| partial_transcript.len_received() != recv_len
{
return Err(VerifierError::verify(
"prover sent transcript with incorrect length",
));
}
decode_transcript(
vm,
partial_transcript.sent_authed(),
partial_transcript.received_authed(),
transcript_refs,
)
.map_err(VerifierError::zk)?;
}
let mut transcript_commitments = Vec::new();
let mut hash_commitments = None;
if let Some(commit_config) = transcript_commit {
if commit_config.encoding() {
let commitment = mux_fut
.poll_with(encoding::transfer(
ctx,
transcript_refs,
delta,
|plaintext| vm.get_keys(plaintext).expect("reference is valid"),
))
.await?;
transcript_commitments.push(TranscriptCommitment::Encoding(commitment));
}
if commit_config.has_hash() {
hash_commitments = Some(
verify_hash(vm, transcript_refs, commit_config.iter_hash().cloned())
.map_err(VerifierError::verify)?,
);
}
}
mux_fut
.poll_with(vm.execute_all(ctx).map_err(VerifierError::zk))
.await?; .await?;
// Verify revealed data. let output = mux_fut
if let Some(partial_transcript) = &transcript { .poll_with(verify::verify(
verify_transcript(vm, partial_transcript, transcript_refs) ctx,
.map_err(VerifierError::verify)?; vm,
} keys,
&cert_verifier,
tls_transcript,
request,
))
.await?;
if let Some(hash_commitments) = hash_commitments { Ok(output)
for commitment in hash_commitments.try_recv().map_err(VerifierError::verify)? {
transcript_commitments.push(TranscriptCommitment::Hash(commitment));
}
}
Ok(VerifierOutput {
server_name,
transcript,
transcript_commitments,
})
} }
/// Closes the connection with the prover. /// Closes the connection with the prover.
@@ -447,11 +294,11 @@ impl Verifier<state::Committed> {
 fn build_mpc_tls(
     config: &VerifierConfig,
     protocol_config: &ProtocolConfig,
-    delta: Delta,
     ctx: Context,
 ) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsFollower) {
     let mut rng = rand::rng();
+    let delta = Delta::random(&mut rng);
     let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
     let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
     let rcot_send = mpz_ot::kos::Sender::new(


@@ -1,4 +1,4 @@
-use crate::{encoding::EncodingError, zk_aes_ctr::ZkAesCtrError};
+use crate::encoding::EncodingError;
 use mpc_tls::MpcTlsError;
 use std::{error::Error, fmt};
@@ -110,12 +110,6 @@ impl From<MpcTlsError> for VerifierError {
     }
 }
-impl From<ZkAesCtrError> for VerifierError {
-    fn from(e: ZkAesCtrError) -> Self {
-        Self::new(ErrorKind::Zk, e)
-    }
-}
 impl From<EncodingError> for VerifierError {
     fn from(e: EncodingError) -> Self {
         Self::new(ErrorKind::Commit, e)

View File

@@ -2,14 +2,9 @@
 use std::sync::Arc;

-use crate::{
-    commit::transcript::TranscriptRefs,
-    mux::{MuxControl, MuxFuture},
-    zk_aes_ctr::ZkAesCtr,
-};
+use crate::mux::{MuxControl, MuxFuture};
 use mpc_tls::{MpcTlsFollower, SessionKeys};
 use mpz_common::Context;
-use mpz_memory_core::correlated::Delta;
 use tlsn_core::transcript::TlsTranscript;
 use tlsn_deap::Deap;
 use tokio::sync::Mutex;
@@ -28,10 +23,7 @@ opaque_debug::implement!(Initialized);
 pub struct Setup {
     pub(crate) mux_ctrl: MuxControl,
     pub(crate) mux_fut: MuxFuture,
-    pub(crate) delta: Delta,
     pub(crate) mpc_tls: MpcTlsFollower,
-    pub(crate) zk_aes_ctr_sent: ZkAesCtr,
-    pub(crate) zk_aes_ctr_recv: ZkAesCtr,
     pub(crate) keys: SessionKeys,
     pub(crate) vm: Arc<Mutex<Deap<Mpc, Zk>>>,
 }
@@ -40,11 +32,10 @@ pub struct Setup {
 pub struct Committed {
     pub(crate) mux_ctrl: MuxControl,
     pub(crate) mux_fut: MuxFuture,
-    pub(crate) delta: Delta,
     pub(crate) ctx: Context,
     pub(crate) vm: Zk,
+    pub(crate) keys: SessionKeys,
     pub(crate) tls_transcript: TlsTranscript,
-    pub(crate) transcript_refs: TranscriptRefs,
 }

 opaque_debug::implement!(Committed);

View File

@@ -0,0 +1,183 @@
use mpc_tls::SessionKeys;
use mpz_common::Context;
use mpz_memory_core::binary::Binary;
use mpz_vm_core::Vm;
use rangeset::{RangeSet, UnionMut};
use tlsn_core::{
ProveRequest, VerifierOutput,
transcript::{
ContentType, Direction, PartialTranscript, Record, TlsTranscript, TranscriptCommitment,
},
webpki::ServerCertVerifier,
};
use crate::{
commit::{auth::verify_plaintext, hash::verify_hash, transcript::TranscriptRefs},
encoding::{self, KeyStore},
verifier::VerifierError,
zk_aes_ctr::ZkAesCtr,
};
pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
ctx: &mut Context,
vm: &mut T,
keys: &SessionKeys,
cert_verifier: &ServerCertVerifier,
tls_transcript: &TlsTranscript,
request: ProveRequest,
) -> Result<VerifierOutput, VerifierError> {
let ProveRequest {
handshake,
transcript,
transcript_commit,
} = request;
let ciphertext_sent = collect_ciphertext(tls_transcript.sent());
let ciphertext_recv = collect_ciphertext(tls_transcript.recv());
let has_reveal = transcript.is_some();
let transcript = if let Some(transcript) = transcript {
if transcript.len_sent() != ciphertext_sent.len()
|| transcript.len_received() != ciphertext_recv.len()
{
return Err(VerifierError::verify(
"prover sent transcript with incorrect length",
));
}
transcript
} else {
PartialTranscript::new(ciphertext_sent.len(), ciphertext_recv.len())
};
let server_name = if let Some((name, cert_data)) = handshake {
cert_data
.verify(
cert_verifier,
tls_transcript.time(),
tls_transcript.server_ephemeral_key(),
&name,
)
.map_err(VerifierError::verify)?;
Some(name)
} else {
None
};
let mut auth_sent_ranges = RangeSet::default();
let mut auth_recv_ranges = RangeSet::default();
auth_sent_ranges.union_mut(transcript.sent_authed());
auth_recv_ranges.union_mut(transcript.received_authed());
if let Some(commit_config) = transcript_commit.as_ref() {
commit_config
.iter_hash()
.for_each(|(direction, idx, _)| match direction {
Direction::Sent => auth_sent_ranges.union_mut(idx),
Direction::Received => auth_recv_ranges.union_mut(idx),
});
if let Some((sent, recv)) = commit_config.encoding() {
auth_sent_ranges.union_mut(sent);
auth_recv_ranges.union_mut(recv);
}
}
let mut zk_aes_sent = ZkAesCtr::new(
keys.client_write_key,
keys.client_write_iv,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
);
let mut zk_aes_recv = ZkAesCtr::new(
keys.server_write_key,
keys.server_write_iv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
);
let (sent_refs, sent_proof) = verify_plaintext(
vm,
&mut zk_aes_sent,
transcript.sent_unsafe(),
&ciphertext_sent,
&auth_sent_ranges,
transcript.sent_authed(),
)
.map_err(VerifierError::zk)?;
let (recv_refs, recv_proof) = verify_plaintext(
vm,
&mut zk_aes_recv,
transcript.received_unsafe(),
&ciphertext_recv,
&auth_recv_ranges,
transcript.received_authed(),
)
.map_err(VerifierError::zk)?;
let transcript_refs = TranscriptRefs {
sent: sent_refs,
recv: recv_refs,
};
let mut transcript_commitments = Vec::new();
let mut hash_commitments = None;
if let Some(commit_config) = transcript_commit.as_ref()
&& commit_config.has_hash()
{
hash_commitments = Some(
verify_hash(vm, &transcript_refs, commit_config.iter_hash().cloned())
.map_err(VerifierError::verify)?,
);
}
vm.execute_all(ctx).await.map_err(VerifierError::zk)?;
sent_proof.verify().map_err(VerifierError::verify)?;
recv_proof.verify().map_err(VerifierError::verify)?;
if let Some(commit_config) = transcript_commit
&& let Some((sent, recv)) = commit_config.encoding()
{
let sent_map = transcript_refs
.sent
.index(sent)
.expect("ranges were authenticated");
let recv_map = transcript_refs
.recv
.index(recv)
.expect("ranges were authenticated");
let commitment = encoding::transfer(ctx, vm, &sent_map, &recv_map).await?;
transcript_commitments.push(TranscriptCommitment::Encoding(commitment));
}
if let Some(hash_commitments) = hash_commitments {
for commitment in hash_commitments.try_recv().map_err(VerifierError::verify)? {
transcript_commitments.push(TranscriptCommitment::Hash(commitment));
}
}
Ok(VerifierOutput {
server_name,
transcript: has_reveal.then_some(transcript),
transcript_commitments,
})
}
fn collect_ciphertext<'a>(records: impl IntoIterator<Item = &'a Record>) -> Vec<u8> {
let mut ciphertext = Vec::new();
records
.into_iter()
.filter(|record| record.typ == ContentType::ApplicationData)
.for_each(|record| {
ciphertext.extend_from_slice(&record.ciphertext);
});
ciphertext
}
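
Note for reviewers (not part of the patch): the range bookkeeping above boils down to taking the union of the revealed ranges and every committed range, and proving plaintext only for that union. A minimal standalone sketch, using the same rangeset calls as the code above and made-up ranges:

use rangeset::{RangeSet, UnionMut};

fn main() {
    // Byte ranges the prover asks to reveal in the clear (illustrative values).
    let reveal_sent: RangeSet<usize> = RangeSet::from(0..10);
    // Byte ranges covered by hash and encoding commitments (illustrative values).
    let hash_sent: RangeSet<usize> = RangeSet::from(5..20);
    let encoding_sent: RangeSet<usize> = RangeSet::from(15..32);

    // The verifier only proves plaintext for the union of all of the above:
    // every byte that is revealed or committed must be consistent with the
    // TLS ciphertext, and nothing outside the union is touched.
    let mut auth_sent: RangeSet<usize> = RangeSet::default();
    auth_sent.union_mut(&reveal_sent);
    auth_sent.union_mut(&hash_sent);
    auth_sent.union_mut(&encoding_sent);

    assert_eq!(auth_sent.len(), 32); // 0..32 for these example ranges
    for range in auth_sent.iter_ranges() {
        println!("authenticate sent bytes {range:?}");
    }
}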

View File

@@ -1,181 +1,226 @@
-//! Zero-knowledge AES-CTR encryption.
-
-use cipher::{
-    Cipher, CipherError, Keystream,
-    aes::{Aes128, AesError},
-};
-use mpz_memory_core::{
-    Array, Vector,
-    binary::{Binary, U8},
-};
-use mpz_vm_core::{Vm, prelude::*};
-
-use crate::Role;
-
-type Nonce = Array<U8, 8>;
-type Ctr = Array<U8, 4>;
-type Block = Array<U8, 16>;
-
-const START_CTR: u32 = 2;
-
-/// ZK AES-CTR encryption.
-#[derive(Debug)]
-pub(crate) struct ZkAesCtr {
-    role: Role,
-    aes: Aes128,
-    state: State,
-}
-
-impl ZkAesCtr {
-    /// Creates a new ZK AES-CTR instance.
-    pub(crate) fn new(role: Role) -> Self {
-        Self {
-            role,
-            aes: Aes128::default(),
-            state: State::Init,
-        }
-    }
-
-    /// Returns the role.
-    pub(crate) fn role(&self) -> &Role {
-        &self.role
-    }
-
-    /// Allocates `len` bytes for encryption.
-    pub(crate) fn alloc(
-        &mut self,
-        vm: &mut dyn Vm<Binary>,
-        len: usize,
-    ) -> Result<(), ZkAesCtrError> {
-        let State::Init = self.state.take() else {
-            Err(ErrorRepr::State {
-                reason: "must be in init state to allocate",
-            })?
-        };
-
-        // Round up to the nearest block size.
-        let len = 16 * len.div_ceil(16);
-
-        let input = vm.alloc_vec::<U8>(len).map_err(ZkAesCtrError::vm)?;
-        let keystream = self.aes.alloc_keystream(vm, len)?;
-
-        match self.role {
-            Role::Prover => vm.mark_private(input).map_err(ZkAesCtrError::vm)?,
-            Role::Verifier => vm.mark_blind(input).map_err(ZkAesCtrError::vm)?,
-        }
-
-        self.state = State::Ready { input, keystream };
-
-        Ok(())
-    }
-
-    /// Sets the key and IV for the cipher.
-    pub(crate) fn set_key(&mut self, key: Array<U8, 16>, iv: Array<U8, 4>) {
-        self.aes.set_key(key);
-        self.aes.set_iv(iv);
-    }
-
-    /// Proves the encryption of `len` bytes.
-    ///
-    /// Here we only assign certain values in the VM but the actual proving
-    /// happens later when the plaintext is assigned and the VM is executed.
-    ///
-    /// # Arguments
-    ///
-    /// * `vm` - Virtual machine.
-    /// * `explicit_nonce` - Explicit nonce.
-    /// * `len` - Length of the plaintext in bytes.
-    ///
-    /// # Returns
-    ///
-    /// A VM reference to the plaintext and the ciphertext.
-    pub(crate) fn encrypt(
-        &mut self,
-        vm: &mut dyn Vm<Binary>,
-        explicit_nonce: Vec<u8>,
-        len: usize,
-    ) -> Result<(Vector<U8>, Vector<U8>), ZkAesCtrError> {
-        let State::Ready { input, keystream } = &mut self.state else {
-            Err(ErrorRepr::State {
-                reason: "must be in ready state to encrypt",
-            })?
-        };
-
-        let explicit_nonce: [u8; 8] =
-            explicit_nonce
-                .try_into()
-                .map_err(|explicit_nonce: Vec<_>| ErrorRepr::ExplicitNonceLength {
-                    expected: 8,
-                    actual: explicit_nonce.len(),
-                })?;
-
-        let block_count = len.div_ceil(16);
-        let padded_len = block_count * 16;
-        let padding_len = padded_len - len;
-
-        if padded_len > input.len() {
-            Err(ErrorRepr::InsufficientPreprocessing {
-                expected: padded_len,
-                actual: input.len(),
-            })?
-        }
-
-        let mut input = input.split_off(input.len() - padded_len);
-        let keystream = keystream.consume(padded_len)?;
-        let mut output = keystream.apply(vm, input)?;
-
-        // Assign counter block inputs.
-        let mut ctr = START_CTR..;
-        keystream.assign(vm, explicit_nonce, move || {
-            ctr.next().expect("range is unbounded").to_be_bytes()
-        })?;
-
-        // Assign zeroes to the padding.
-        if padding_len > 0 {
-            let padding = input.split_off(input.len() - padding_len);
-            // To simplify the impl, we don't mark the padding as public, that's why only
-            // the prover assigns it.
-            if let Role::Prover = self.role {
-                vm.assign(padding, vec![0; padding_len])
-                    .map_err(ZkAesCtrError::vm)?;
-            }
-            vm.commit(padding).map_err(ZkAesCtrError::vm)?;
-
-            output.truncate(len);
-        }
-
-        Ok((input, output))
-    }
-}
-
-enum State {
-    Init,
-    Ready {
-        input: Vector<U8>,
-        keystream: Keystream<Nonce, Ctr, Block>,
-    },
-    Error,
-}
-
-impl State {
-    fn take(&mut self) -> Self {
-        std::mem::replace(self, State::Error)
-    }
-}
-
-impl std::fmt::Debug for State {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            State::Init => write!(f, "Init"),
-            State::Ready { .. } => write!(f, "Ready"),
-            State::Error => write!(f, "Error"),
-        }
-    }
-}
+use std::{ops::Range, sync::Arc};
+
+use mpz_circuits::circuits::{AES128, xor};
+use mpz_memory_core::{
+    Array, MemoryExt, Vector, ViewExt,
+    binary::{Binary, U8},
+};
+use mpz_vm_core::{Call, CallableExt, Vm};
+use rangeset::RangeSet;
+use tlsn_core::transcript::Record;
+
+use crate::commit::transcript::ReferenceMap;
+
+/// ZK AES-CTR encryption.
+#[derive(Debug)]
+pub(crate) struct ZkAesCtr {
+    key: Array<U8, 16>,
+    iv: Array<U8, 4>,
+    records: Vec<(usize, RecordState)>,
+    total_len: usize,
+}
+
+impl ZkAesCtr {
+    /// Creates a new instance.
+    pub(crate) fn new<'record>(
+        key: Array<U8, 16>,
+        iv: Array<U8, 4>,
+        records: impl IntoIterator<Item = &'record Record>,
+    ) -> Self {
+        let mut pos = 0;
+        let mut record_state = Vec::new();
+        for record in records {
+            record_state.push((
+                pos,
+                RecordState {
+                    explicit_nonce: Some(record.explicit_nonce.clone()),
+                    explicit_nonce_ref: None,
+                    range: pos..pos + record.ciphertext.len(),
+                },
+            ));
+            pos += record.ciphertext.len();
+        }
+
+        Self {
+            key,
+            iv,
+            records: record_state,
+            total_len: pos,
+        }
+    }
+
+    /// Allocates the plaintext for the provided ranges.
+    ///
+    /// Returns a reference to the plaintext and the ciphertext.
+    pub(crate) fn alloc_plaintext(
+        &mut self,
+        vm: &mut dyn Vm<Binary>,
+        ranges: &RangeSet<usize>,
+    ) -> Result<(ReferenceMap, ReferenceMap), ZkAesCtrError> {
+        let len = ranges.len();
+
+        if len > self.total_len {
+            return Err(ZkAesCtrError(ErrorRepr::TranscriptBounds {
+                len,
+                max: self.total_len,
+            }));
+        }
+
+        let plaintext = vm.alloc_vec::<U8>(len).map_err(ZkAesCtrError::vm)?;
+        let keystream = self.alloc_keystream(vm, ranges)?;
+
+        let mut builder = Call::builder(Arc::new(xor(len * 8))).arg(plaintext);
+        for slice in keystream {
+            builder = builder.arg(slice);
+        }
+        let call = builder.build().expect("call should be valid");
+
+        let ciphertext: Vector<U8> = vm.call(call).map_err(ZkAesCtrError::vm)?;
+
+        let mut pos = 0;
+        let plaintext = ReferenceMap::from_iter(ranges.iter_ranges().map(move |range| {
+            let chunk = plaintext
+                .get(pos..pos + range.len())
+                .expect("length was checked");
+            pos += range.len();
+            (range.start, chunk)
+        }));
+
+        let mut pos = 0;
+        let ciphertext = ReferenceMap::from_iter(ranges.iter_ranges().map(move |range| {
+            let chunk = ciphertext
+                .get(pos..pos + range.len())
+                .expect("length was checked");
+            pos += range.len();
+            (range.start, chunk)
+        }));
+
+        Ok((plaintext, ciphertext))
+    }
+
+    fn alloc_keystream(
+        &mut self,
+        vm: &mut dyn Vm<Binary>,
+        ranges: &RangeSet<usize>,
+    ) -> Result<Vec<Vector<U8>>, ZkAesCtrError> {
+        let mut keystream = Vec::new();
+
+        let mut range_iter = ranges.iter_ranges();
+        let mut current_range = range_iter.next();
+        for (pos, record) in self.records.iter_mut() {
+            let pos = *pos;
+            let mut current_block = None;
+            loop {
+                let Some(range) = current_range.take().or_else(|| range_iter.next()) else {
+                    return Ok(keystream);
+                };
+
+                if range.start >= record.range.end {
+                    current_range = Some(range);
+                    break;
+                }
+
+                const BLOCK_SIZE: usize = 16;
+                let block_num = (range.start - pos) / BLOCK_SIZE;
+
+                let block = if let Some((current_block_num, block)) = current_block.take()
+                    && current_block_num == block_num
+                {
+                    block
+                } else {
+                    let block = record.alloc_block(vm, self.key, self.iv, block_num)?;
+                    current_block = Some((block_num, block));
+
+                    block
+                };
+
+                let start = (range.start - pos) % BLOCK_SIZE;
+                let end = (start + range.len()).min(BLOCK_SIZE);
+                let len = end - start;
+
+                keystream.push(block.get(start..end).expect("range is checked"));
+
+                // If the range is larger than a block, process the tail.
+                if range.len() > BLOCK_SIZE {
+                    current_range = Some(range.start + len..range.end);
+                }
+            }
+        }
+
+        unreachable!("plaintext length was checked");
+    }
+}
+
+#[derive(Debug)]
+struct RecordState {
+    explicit_nonce: Option<Vec<u8>>,
+    range: Range<usize>,
+    explicit_nonce_ref: Option<Vector<U8>>,
+}
+
+impl RecordState {
+    fn alloc_explicit_nonce(
+        &mut self,
+        vm: &mut dyn Vm<Binary>,
+    ) -> Result<Vector<U8>, ZkAesCtrError> {
+        if let Some(explicit_nonce) = self.explicit_nonce_ref {
+            Ok(explicit_nonce)
+        } else {
+            const EXPLICIT_NONCE_LEN: usize = 8;
+
+            let explicit_nonce_ref = vm
+                .alloc_vec::<U8>(EXPLICIT_NONCE_LEN)
+                .map_err(ZkAesCtrError::vm)?;
+            vm.mark_public(explicit_nonce_ref)
+                .map_err(ZkAesCtrError::vm)?;
+            vm.assign(
+                explicit_nonce_ref,
+                self.explicit_nonce
+                    .take()
+                    .expect("explicit nonce only set once"),
+            )
+            .map_err(ZkAesCtrError::vm)?;
+            vm.commit(explicit_nonce_ref).map_err(ZkAesCtrError::vm)?;
+
+            self.explicit_nonce_ref = Some(explicit_nonce_ref);
+
+            Ok(explicit_nonce_ref)
+        }
+    }
+
+    fn alloc_block(
+        &mut self,
+        vm: &mut dyn Vm<Binary>,
+        key: Array<U8, 16>,
+        iv: Array<U8, 4>,
+        block: usize,
+    ) -> Result<Vector<U8>, ZkAesCtrError> {
+        let explicit_nonce = self.alloc_explicit_nonce(vm)?;
+
+        let ctr: Array<U8, 4> = vm.alloc().map_err(ZkAesCtrError::vm)?;
+        vm.mark_public(ctr).map_err(ZkAesCtrError::vm)?;
+        const START_CTR: u32 = 2;
+        vm.assign(ctr, (START_CTR + block as u32).to_be_bytes())
+            .map_err(ZkAesCtrError::vm)?;
+        vm.commit(ctr).map_err(ZkAesCtrError::vm)?;
+
+        let block: Array<U8, 16> = vm
+            .call(
+                Call::builder(AES128.clone())
+                    .arg(key)
+                    .arg(iv)
+                    .arg(explicit_nonce)
+                    .arg(ctr)
+                    .build()
+                    .expect("call should be valid"),
+            )
+            .map_err(ZkAesCtrError::vm)?;
+
+        Ok(Vector::from(block))
+    }
+}

 /// Error for [`ZkAesCtr`].
 #[derive(Debug, thiserror::Error)]
 #[error(transparent)]
-pub struct ZkAesCtrError(#[from] ErrorRepr);
+pub(crate) struct ZkAesCtrError(#[from] ErrorRepr);

 impl ZkAesCtrError {
     fn vm<E>(err: E) -> Self
@@ -184,35 +229,13 @@ impl ZkAesCtrError {
     {
         Self(ErrorRepr::Vm(err.into()))
     }
-
-    pub fn is_insufficient(&self) -> bool {
-        matches!(self.0, ErrorRepr::InsufficientPreprocessing { .. })
-    }
 }

 #[derive(Debug, thiserror::Error)]
 #[error("zk aes error")]
 enum ErrorRepr {
-    #[error("invalid state: {reason}")]
-    State { reason: &'static str },
-    #[error("cipher error: {0}")]
-    Cipher(Box<dyn std::error::Error + Send + Sync + 'static>),
     #[error("vm error: {0}")]
     Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
-    #[error("invalid explicit nonce length: expected {expected}, got {actual}")]
-    ExplicitNonceLength { expected: usize, actual: usize },
-    #[error("insufficient preprocessing: expected {expected}, got {actual}")]
-    InsufficientPreprocessing { expected: usize, actual: usize },
+    #[error("transcript bounds exceeded: {len} > {max}")]
+    TranscriptBounds { len: usize, max: usize },
 }
-
-impl From<AesError> for ZkAesCtrError {
-    fn from(err: AesError) -> Self {
-        Self(ErrorRepr::Cipher(Box::new(err)))
-    }
-}
-
-impl From<CipherError> for ZkAesCtrError {
-    fn from(err: CipherError) -> Self {
-        Self(ErrorRepr::Cipher(Box::new(err)))
-    }
-}
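
Note for reviewers (not part of the patch): a plain-Rust sketch of the relation the AES128 and xor circuit calls above encode, assuming the RustCrypto aes crate and the TLS 1.2 AES-GCM counter layout implied by the circuit arguments (key, iv, explicit_nonce, ctr) with START_CTR = 2. Key, IV, and nonce values below are made up; in the protocol the key and IV never leave the VM.

use aes::Aes128;
use aes::cipher::{BlockEncrypt, KeyInit, generic_array::GenericArray};

/// One 16-byte keystream block for a TLS 1.2 AES-GCM record:
/// AES-128(key, implicit_iv(4) || explicit_nonce(8) || counter(4, big-endian)),
/// with the counter starting at 2 for the first block of the payload.
fn keystream_block(key: [u8; 16], iv: [u8; 4], explicit_nonce: [u8; 8], block: usize) -> [u8; 16] {
    const START_CTR: u32 = 2; // same offset assigned to the `ctr` input above

    let mut input = [0u8; 16];
    input[..4].copy_from_slice(&iv);
    input[4..12].copy_from_slice(&explicit_nonce);
    input[12..].copy_from_slice(&(START_CTR + block as u32).to_be_bytes());

    let cipher = Aes128::new(&GenericArray::from(key));
    let mut out = GenericArray::from(input);
    cipher.encrypt_block(&mut out);

    let mut keystream = [0u8; 16];
    keystream.copy_from_slice(&out);
    keystream
}

fn main() {
    let ks = keystream_block([7u8; 16], [1, 2, 3, 4], [9u8; 8], 0);
    let plaintext = b"GET / HTTP/1.1\r\n";
    // ciphertext = plaintext XOR keystream, which is the relation the xor
    // circuit in `alloc_plaintext` proves for each selected byte range.
    let ciphertext: Vec<u8> = plaintext.iter().zip(ks).map(|(p, k)| p ^ k).collect();
    assert_eq!(ciphertext.len(), plaintext.len());
}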

View File

@@ -1,9 +1,14 @@
 use futures::{AsyncReadExt, AsyncWriteExt};
+use rangeset::RangeSet;
 use tlsn::{
     config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore},
     connection::ServerName,
+    hash::{HashAlgId, HashProvider},
     prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
-    transcript::{TranscriptCommitConfig, TranscriptCommitment},
+    transcript::{
+        Direction, TranscriptCommitConfig, TranscriptCommitment, TranscriptCommitmentKind,
+        TranscriptSecret,
+    },
     verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
 };
 use tlsn_server_fixture::bind;
@@ -86,9 +91,25 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(verifier_soc
     let mut builder = TranscriptCommitConfig::builder(prover.transcript());

-    // Commit to everything
-    builder.commit_sent(&(0..sent_tx_len)).unwrap();
-    builder.commit_recv(&(0..recv_tx_len)).unwrap();
+    for kind in [
+        TranscriptCommitmentKind::Encoding,
+        TranscriptCommitmentKind::Hash {
+            alg: HashAlgId::SHA256,
+        },
+    ] {
+        builder
+            .commit_with_kind(&(0..sent_tx_len), Direction::Sent, kind)
+            .unwrap();
+        builder
+            .commit_with_kind(&(0..recv_tx_len), Direction::Received, kind)
+            .unwrap();
+        builder
+            .commit_with_kind(&(1..sent_tx_len - 1), Direction::Sent, kind)
+            .unwrap();
+        builder
+            .commit_with_kind(&(1..recv_tx_len - 1), Direction::Received, kind)
+            .unwrap();
+    }

     let transcript_commit = builder.build().unwrap();
@@ -102,9 +123,52 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(verifier_soc
     builder.transcript_commit(transcript_commit);

     let config = builder.build().unwrap();

-    prover.prove(&config).await.unwrap();
+    let transcript = prover.transcript().clone();
+
+    let output = prover.prove(&config).await.unwrap();
     prover.close().await.unwrap();
+
+    let encoding_tree = output
+        .transcript_secrets
+        .iter()
+        .find_map(|secret| {
+            if let TranscriptSecret::Encoding(tree) = secret {
+                Some(tree)
+            } else {
+                None
+            }
+        })
+        .unwrap();
+    let encoding_commitment = output
+        .transcript_commitments
+        .iter()
+        .find_map(|commitment| {
+            if let TranscriptCommitment::Encoding(commitment) = commitment {
+                Some(commitment)
+            } else {
+                None
+            }
+        })
+        .unwrap();
+
+    let prove_sent = RangeSet::from(1..sent_tx_len - 1);
+    let prove_recv = RangeSet::from(1..recv_tx_len - 1);
+    let idxs = [
+        (Direction::Sent, prove_sent.clone()),
+        (Direction::Received, prove_recv.clone()),
+    ];
+
+    let proof = encoding_tree.proof(idxs.iter()).unwrap();
+    let (auth_sent, auth_recv) = proof
+        .verify_with_provider(
+            &HashProvider::default(),
+            encoding_commitment,
+            transcript.sent(),
+            transcript.received(),
+        )
+        .unwrap();
+
+    assert_eq!(auth_sent, prove_sent);
+    assert_eq!(auth_recv, prove_recv);
 }

 #[instrument(skip(socket))]
@@ -125,14 +189,21 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(soc
             .unwrap(),
     );

+    let mut verifier = verifier
+        .setup(socket.compat())
+        .await
+        .unwrap()
+        .run()
+        .await
+        .unwrap();
+
     let VerifierOutput {
         server_name,
         transcript,
         transcript_commitments,
-    } = verifier
-        .verify(socket.compat(), &VerifyConfig::default())
-        .await
-        .unwrap();
+    } = verifier.verify(&VerifyConfig::default()).await.unwrap();
+
+    verifier.close().await.unwrap();

     let transcript = transcript.unwrap();