feat(tlsn): partial plaintext auth (#1006)

Co-authored-by: th4s <th4s@metavoid.xyz>
This commit is contained in:
sinu.eth
2025-10-09 11:22:23 -07:00
committed by GitHub
parent df8d79c152
commit 2884be17e0
22 changed files with 1542 additions and 1294 deletions

631
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -66,19 +66,19 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
tlsn-wasm = { path = "crates/wasm" }
tlsn = { path = "crates/tlsn" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
rangeset = { version = "0.2" }
serio = { version = "0.2" }

View File

@@ -190,10 +190,10 @@ pub struct VerifyConfigBuilderError(#[from] VerifyConfigBuilderErrorRepr);
#[derive(Debug, thiserror::Error)]
enum VerifyConfigBuilderErrorRepr {}
/// Payload sent to the verifier.
/// Request to prove statements about the connection.
#[doc(hidden)]
#[derive(Debug, Serialize, Deserialize)]
pub struct ProvePayload {
pub struct ProveRequest {
/// Handshake data.
pub handshake: Option<(ServerName, HandshakeData)>,
/// Transcript data.

View File

@@ -2,7 +2,7 @@
use std::{collections::HashSet, fmt};
use rangeset::ToRangeSet;
use rangeset::{ToRangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use crate::{
@@ -114,7 +114,19 @@ impl TranscriptCommitConfig {
/// Returns a request for the transcript commitments.
pub fn to_request(&self) -> TranscriptCommitRequest {
TranscriptCommitRequest {
encoding: self.has_encoding,
encoding: self.has_encoding.then(|| {
let mut sent = RangeSet::default();
let mut recv = RangeSet::default();
for (dir, idx) in self.iter_encoding() {
match dir {
Direction::Sent => sent.union_mut(idx),
Direction::Received => recv.union_mut(idx),
}
}
(sent, recv)
}),
hash: self
.iter_hash()
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg))
@@ -289,14 +301,14 @@ impl fmt::Display for TranscriptCommitConfigBuilderError {
/// Request to compute transcript commitments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptCommitRequest {
encoding: bool,
encoding: Option<(RangeSet<usize>, RangeSet<usize>)>,
hash: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
}
impl TranscriptCommitRequest {
/// Returns `true` if an encoding commitment is requested.
pub fn encoding(&self) -> bool {
self.encoding
pub fn has_encoding(&self) -> bool {
self.encoding.is_some()
}
/// Returns `true` if a hash commitment is requested.
@@ -308,6 +320,11 @@ impl TranscriptCommitRequest {
pub fn iter_hash(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>, HashAlgId)> {
self.hash.iter()
}
/// Returns the ranges of the encoding commitments.
pub fn encoding(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
self.encoding.as_ref()
}
}
#[cfg(test)]

View File

@@ -31,6 +31,7 @@ web-spawn = { workspace = true, optional = true }
mpz-common = { workspace = true }
mpz-core = { workspace = true }
mpz-circuits = { workspace = true }
mpz-garble = { workspace = true }
mpz-garble-core = { workspace = true }
mpz-hash = { workspace = true }

View File

@@ -1,116 +1,5 @@
//! Plaintext commitment and proof of encryption.
pub(crate) mod auth;
pub(crate) mod hash;
pub(crate) mod transcript;
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{
DecodeFutureTyped, Vector,
binary::{Binary, U8},
};
use mpz_vm_core::{Vm, prelude::*};
use tlsn_core::transcript::Record;
use crate::{
Role,
zk_aes_ctr::{ZkAesCtr, ZkAesCtrError},
};
/// Commits the plaintext of the provided records, returning a proof of
/// encryption.
///
/// Writes the plaintext VM reference to the provided records.
pub(crate) fn commit_records<'record>(
vm: &mut dyn Vm<Binary>,
aes: &mut ZkAesCtr,
records: impl IntoIterator<Item = &'record Record>,
) -> Result<(Vec<Vector<U8>>, RecordProof), RecordProofError> {
let mut plaintexts = Vec::new();
let mut ciphertexts = Vec::new();
for record in records {
let (plaintext_ref, ciphertext_ref) = aes
.encrypt(vm, record.explicit_nonce.clone(), record.ciphertext.len())
.map_err(ErrorRepr::Aes)?;
if let Role::Prover = aes.role() {
let Some(plaintext) = record.plaintext.clone() else {
return Err(ErrorRepr::MissingPlaintext.into());
};
vm.assign(plaintext_ref, plaintext)
.map_err(RecordProofError::vm)?;
}
vm.commit(plaintext_ref).map_err(RecordProofError::vm)?;
let ciphertext = vm.decode(ciphertext_ref).map_err(RecordProofError::vm)?;
plaintexts.push(plaintext_ref);
ciphertexts.push((ciphertext, record.ciphertext.clone()));
}
Ok((plaintexts, RecordProof { ciphertexts }))
}
/// Proof of encryption.
#[derive(Debug)]
#[must_use]
#[allow(clippy::type_complexity)]
pub(crate) struct RecordProof {
ciphertexts: Vec<(DecodeFutureTyped<BitVec, Vec<u8>>, Vec<u8>)>,
}
impl RecordProof {
/// Verifies the proof.
pub(crate) fn verify(self) -> Result<(), RecordProofError> {
let Self { ciphertexts } = self;
for (mut ciphertext, expected) in ciphertexts {
let ciphertext = ciphertext
.try_recv()
.map_err(RecordProofError::vm)?
.ok_or_else(|| ErrorRepr::NotDecoded)?;
if ciphertext != expected {
return Err(ErrorRepr::InvalidCiphertext.into());
}
}
Ok(())
}
}
/// Error for [`RecordProof`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub(crate) struct RecordProofError(#[from] ErrorRepr);
impl RecordProofError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
pub(crate) fn is_insufficient(&self) -> bool {
match &self.0 {
ErrorRepr::Aes(err) => err.is_insufficient(),
_ => false,
}
}
}
#[derive(Debug, thiserror::Error)]
#[error("record proof error: {0}")]
enum ErrorRepr {
#[error("VM error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("zk aes error: {0}")]
Aes(ZkAesCtrError),
#[error("plaintext is missing")]
MissingPlaintext,
#[error("ciphertext was not decoded")]
NotDecoded,
#[error("ciphertext does not match expected")]
InvalidCiphertext,
}

View File

@@ -0,0 +1,166 @@
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{DecodeFutureTyped, MemoryExt, ViewExt, binary::Binary};
use mpz_vm_core::Vm;
use rangeset::{Difference, RangeSet, Subset};
use crate::{
commit::transcript::ReferenceMap,
zk_aes_ctr::{ZkAesCtr, ZkAesCtrError},
};
pub(crate) fn prove_plaintext(
vm: &mut dyn Vm<Binary>,
zk_aes: &mut ZkAesCtr,
plaintext: &[u8],
ranges: &RangeSet<usize>,
public: &RangeSet<usize>,
) -> Result<ReferenceMap, PlaintextAuthError> {
assert!(public.is_subset(ranges), "public is not a subset of ranges");
if ranges.is_empty() {
return Ok(ReferenceMap::default());
}
let (plaintext_map, ciphertext_map) = zk_aes
.alloc_plaintext(vm, ranges)
.map_err(ErrorRepr::ZkAesCtr)?;
for (range, chunk) in plaintext_map
.index(&ranges.difference(public))
.expect("map contains all ranges")
.iter()
{
vm.mark_private(*chunk).map_err(PlaintextAuthError::vm)?;
vm.assign(*chunk, plaintext[range].to_vec())
.map_err(PlaintextAuthError::vm)?;
vm.commit(*chunk).map_err(PlaintextAuthError::vm)?;
}
for (range, chunk) in plaintext_map
.index(public)
.expect("map contains all ranges")
.iter()
{
vm.mark_public(*chunk).map_err(PlaintextAuthError::vm)?;
vm.assign(*chunk, plaintext[range].to_vec())
.map_err(PlaintextAuthError::vm)?;
vm.commit(*chunk).map_err(PlaintextAuthError::vm)?;
}
for (_, chunk) in ciphertext_map.iter() {
drop(vm.decode(*chunk).map_err(PlaintextAuthError::vm)?);
}
Ok(plaintext_map)
}
pub(crate) fn verify_plaintext(
vm: &mut dyn Vm<Binary>,
zk_aes: &mut ZkAesCtr,
plaintext: &[u8],
ciphertext: &[u8],
ranges: &RangeSet<usize>,
public: &RangeSet<usize>,
) -> Result<(ReferenceMap, PlaintextProof), PlaintextAuthError> {
assert!(public.is_subset(ranges), "public is not a subset of ranges");
if ranges.is_empty() {
return Ok((
ReferenceMap::default(),
PlaintextProof {
ciphertexts: vec![],
},
));
}
let (plaintext_map, ciphertext_map) = zk_aes
.alloc_plaintext(vm, ranges)
.map_err(ErrorRepr::ZkAesCtr)?;
for (_, chunk) in plaintext_map
.index(&ranges.difference(public))
.expect("map contains all ranges")
.iter()
{
vm.mark_blind(*chunk).map_err(PlaintextAuthError::vm)?;
vm.commit(*chunk).map_err(PlaintextAuthError::vm)?;
}
for (range, chunk) in plaintext_map
.index(public)
.expect("map contains all ranges")
.iter()
{
vm.mark_public(*chunk).map_err(PlaintextAuthError::vm)?;
vm.assign(*chunk, plaintext[range].to_vec())
.map_err(PlaintextAuthError::vm)?;
vm.commit(*chunk).map_err(PlaintextAuthError::vm)?;
}
let mut ciphertexts = Vec::new();
for (range, chunk) in ciphertext_map
.index(ranges)
.expect("map contains all ranges")
.iter()
{
ciphertexts.push((
ciphertext[range].to_vec(),
vm.decode(*chunk).map_err(PlaintextAuthError::vm)?,
));
}
Ok((plaintext_map, PlaintextProof { ciphertexts }))
}
#[derive(Debug, thiserror::Error)]
#[error("plaintext authentication error: {0}")]
pub(crate) struct PlaintextAuthError(#[from] ErrorRepr);
impl PlaintextAuthError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
}
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("vm error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("zk aes ctr error: {0}")]
ZkAesCtr(ZkAesCtrError),
#[error("missing decoding")]
MissingDecoding,
#[error("invalid ciphertext")]
InvalidCiphertext,
}
#[must_use]
pub(crate) struct PlaintextProof {
// (expected, actual)
#[allow(clippy::type_complexity)]
ciphertexts: Vec<(Vec<u8>, DecodeFutureTyped<BitVec, Vec<u8>>)>,
}
impl PlaintextProof {
pub(crate) fn verify(self) -> Result<(), PlaintextAuthError> {
let Self {
ciphertexts: ciphertext,
} = self;
for (expected, mut actual) in ciphertext {
let actual = actual
.try_recv()
.map_err(PlaintextAuthError::vm)?
.ok_or(PlaintextAuthError(ErrorRepr::MissingDecoding))?;
if actual != expected {
return Err(PlaintextAuthError(ErrorRepr::InvalidCiphertext));
}
}
Ok(())
}
}

View File

@@ -149,9 +149,15 @@ fn hash_commit_inner(
hasher
};
for plaintext in refs.get(direction, &idx).expect("plaintext refs are valid") {
hasher.update(&plaintext);
let refs = match direction {
Direction::Sent => &refs.sent,
Direction::Received => &refs.recv,
};
for range in idx.iter_ranges() {
hasher.update(&refs.get(range).expect("plaintext refs are valid"));
}
hasher.update(&blinder);
hasher.finalize(vm).map_err(HashCommitError::hasher)?
}
@@ -164,9 +170,14 @@ fn hash_commit_inner(
hasher
};
for plaintext in refs.get(direction, &idx).expect("plaintext refs are valid") {
let refs = match direction {
Direction::Sent => &refs.sent,
Direction::Received => &refs.recv,
};
for range in idx.iter_ranges() {
hasher
.update(vm, &plaintext)
.update(vm, &refs.get(range).expect("plaintext refs are valid"))
.map_err(HashCommitError::hasher)?;
}
hasher

View File

@@ -1,211 +1,205 @@
use mpz_memory_core::{
MemoryExt, Vector,
binary::{Binary, U8},
};
use mpz_vm_core::{Vm, VmError};
use rangeset::{Intersection, RangeSet};
use tlsn_core::transcript::{Direction, PartialTranscript};
use std::ops::Range;
use mpz_memory_core::{Vector, binary::U8};
use rangeset::RangeSet;
pub(crate) type ReferenceMap = RangeMap<Vector<U8>>;
/// References to the application plaintext in the transcript.
#[derive(Debug, Default, Clone)]
pub(crate) struct TranscriptRefs {
sent: Vec<Vector<U8>>,
recv: Vec<Vector<U8>>,
pub(crate) sent: ReferenceMap,
pub(crate) recv: ReferenceMap,
}
impl TranscriptRefs {
pub(crate) fn new(sent: Vec<Vector<U8>>, recv: Vec<Vector<U8>>) -> Self {
Self { sent, recv }
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct RangeMap<T> {
map: Vec<(usize, T)>,
}
impl<T> Default for RangeMap<T>
where
T: Item,
{
fn default() -> Self {
Self { map: Vec::new() }
}
}
/// Returns the sent plaintext references.
pub(crate) fn sent(&self) -> &[Vector<U8>] {
&self.sent
}
impl<T> RangeMap<T>
where
T: Item,
{
pub(crate) fn new(map: Vec<(usize, T)>) -> Self {
let mut pos = 0;
for (idx, item) in &map {
assert!(
*idx >= pos,
"items must be sorted by index and non-overlapping"
);
/// Returns the received plaintext references.
pub(crate) fn recv(&self) -> &[Vector<U8>] {
&self.recv
}
/// Returns the transcript lengths.
pub(crate) fn len(&self) -> (usize, usize) {
let sent = self.sent.iter().map(|v| v.len()).sum();
let recv = self.recv.iter().map(|v| v.len()).sum();
(sent, recv)
}
/// Returns VM references for the given direction and index, otherwise
/// `None` if the index is out of bounds.
pub(crate) fn get(
&self,
direction: Direction,
idx: &RangeSet<usize>,
) -> Option<Vec<Vector<U8>>> {
if idx.is_empty() {
return Some(Vec::new());
pos = *idx + item.length();
}
let refs = match direction {
Direction::Sent => &self.sent,
Direction::Received => &self.recv,
Self { map }
}
/// Returns the length of the map.
pub(crate) fn len(&self) -> usize {
self.map.iter().map(|(_, item)| item.length()).sum()
}
pub(crate) fn iter(&self) -> impl Iterator<Item = (Range<usize>, &T)> {
self.map
.iter()
.map(|(idx, item)| (*idx..*idx + item.length(), item))
}
pub(crate) fn get(&self, range: Range<usize>) -> Option<T::Slice<'_>> {
if range.start >= range.end {
return None;
}
// Find the item with the greatest start index <= range.start
let pos = match self.map.binary_search_by(|(idx, _)| idx.cmp(&range.start)) {
Ok(i) => i,
Err(0) => return None,
Err(i) => i - 1,
};
// Computes the transcript range for each reference.
let mut start = 0;
let mut slice_iter = refs.iter().map(move |slice| {
let out = (slice, start..start + slice.len());
start += slice.len();
out
});
let (base, item) = &self.map[pos];
let mut slices = Vec::new();
let (mut slice, mut slice_range) = slice_iter.next()?;
for range in idx.iter_ranges() {
loop {
if let Some(intersection) = slice_range.intersection(&range) {
let start = intersection.start - slice_range.start;
let end = intersection.end - slice_range.start;
slices.push(slice.get(start..end).expect("range should be in bounds"));
}
item.slice(range.start - *base..range.end - *base)
}
// Proceed to next range if the current slice extends beyond. Otherwise, proceed
// to the next slice.
if range.end <= slice_range.end {
break;
} else {
(slice, slice_range) = slice_iter.next()?;
}
pub(crate) fn index(&self, idx: &RangeSet<usize>) -> Option<Self> {
let mut map = Vec::new();
for idx in idx.iter_ranges() {
let pos = match self.map.binary_search_by(|(base, _)| base.cmp(&idx.start)) {
Ok(i) => i,
Err(0) => return None,
Err(i) => i - 1,
};
let (base, item) = self.map.get(pos)?;
if idx.start < *base || idx.end > *base + item.length() {
return None;
}
let start = idx.start - *base;
let end = start + idx.len();
map.push((
idx.start,
item.slice(start..end)
.expect("slice length is checked")
.into(),
));
}
Some(slices)
Some(Self { map })
}
}
/// Decodes the transcript.
pub(crate) fn decode_transcript(
vm: &mut dyn Vm<Binary>,
sent: &RangeSet<usize>,
recv: &RangeSet<usize>,
refs: &TranscriptRefs,
) -> Result<(), VmError> {
let sent_refs = refs.get(Direction::Sent, sent).expect("index is in bounds");
let recv_refs = refs
.get(Direction::Received, recv)
.expect("index is in bounds");
impl<T> FromIterator<(usize, T)> for RangeMap<T>
where
T: Item,
{
fn from_iter<I: IntoIterator<Item = (usize, T)>>(items: I) -> Self {
let mut pos = 0;
let mut map = Vec::new();
for (idx, item) in items {
assert!(
idx >= pos,
"items must be sorted by index and non-overlapping"
);
for slice in sent_refs.into_iter().chain(recv_refs) {
// Drop the future, we don't need it.
drop(vm.decode(slice)?);
pos = idx + item.length();
map.push((idx, item));
}
Self { map }
}
Ok(())
}
/// Verifies a partial transcript.
pub(crate) fn verify_transcript(
vm: &mut dyn Vm<Binary>,
transcript: &PartialTranscript,
refs: &TranscriptRefs,
) -> Result<(), InconsistentTranscript> {
let sent_refs = refs
.get(Direction::Sent, transcript.sent_authed())
.expect("index is in bounds");
let recv_refs = refs
.get(Direction::Received, transcript.received_authed())
.expect("index is in bounds");
pub(crate) trait Item: Sized {
type Slice<'a>: Into<Self>
where
Self: 'a;
let mut authenticated_data = Vec::new();
for data in sent_refs.into_iter().chain(recv_refs) {
let plaintext = vm
.get(data)
.expect("reference is valid")
.expect("plaintext is decoded");
authenticated_data.extend_from_slice(&plaintext);
}
fn length(&self) -> usize;
let mut purported_data = Vec::with_capacity(authenticated_data.len());
for range in transcript.sent_authed().iter_ranges() {
purported_data.extend_from_slice(&transcript.sent_unsafe()[range]);
}
for range in transcript.received_authed().iter_ranges() {
purported_data.extend_from_slice(&transcript.received_unsafe()[range]);
}
if purported_data != authenticated_data {
return Err(InconsistentTranscript {});
}
Ok(())
fn slice<'a>(&'a self, range: Range<usize>) -> Option<Self::Slice<'a>>;
}
/// Error for [`verify_transcript`].
#[derive(Debug, thiserror::Error)]
#[error("inconsistent transcript")]
pub(crate) struct InconsistentTranscript {}
impl Item for Vector<U8> {
type Slice<'a> = Vector<U8>;
fn length(&self) -> usize {
self.len()
}
fn slice<'a>(&'a self, range: Range<usize>) -> Option<Self::Slice<'a>> {
self.get(range)
}
}
#[cfg(test)]
mod tests {
use super::TranscriptRefs;
use mpz_memory_core::{FromRaw, Slice, Vector, binary::U8};
use rangeset::RangeSet;
use std::ops::Range;
use tlsn_core::transcript::Direction;
use super::*;
// TRANSCRIPT_REFS:
//
// 48..96 -> 6 slots
// 112..176 -> 8 slots
// 240..288 -> 6 slots
// 352..392 -> 5 slots
// 440..480 -> 5 slots
const TRANSCRIPT_REFS: &[Range<usize>] = &[48..96, 112..176, 240..288, 352..392, 440..480];
impl Item for Range<usize> {
type Slice<'a> = Range<usize>;
const IDXS: &[Range<usize>] = &[0..4, 5..10, 14..16, 16..28];
fn length(&self) -> usize {
self.end - self.start
}
// 1. Take slots 0..4, 4 slots -> 48..80 (4)
// 2. Take slots 5..10, 5 slots -> 88..96 (1) + 112..144 (4)
// 3. Take slots 14..16, 2 slots -> 240..256 (2)
// 4. Take slots 16..28, 12 slots -> 256..288 (4) + 352..392 (5) + 440..464 (3)
//
// 5. Merge slots 240..256 and 256..288 => 240..288 and get EXPECTED_REFS
const EXPECTED_REFS: &[Range<usize>] =
&[48..80, 88..96, 112..144, 240..288, 352..392, 440..464];
fn slice(&self, range: Range<usize>) -> Option<Self> {
if range.end > self.end - self.start {
return None;
}
#[test]
fn test_transcript_refs_get() {
let transcript_refs: Vec<Vector<U8>> = TRANSCRIPT_REFS
.iter()
.cloned()
.map(|range| Vector::from_raw(Slice::from_range_unchecked(range)))
.collect();
let transcript_refs = TranscriptRefs {
sent: transcript_refs.clone(),
recv: transcript_refs,
};
let vm_refs = transcript_refs
.get(Direction::Sent, &RangeSet::from(IDXS))
.unwrap();
let expected_refs: Vec<Vector<U8>> = EXPECTED_REFS
.iter()
.cloned()
.map(|range| Vector::from_raw(Slice::from_range_unchecked(range)))
.collect();
assert_eq!(
vm_refs.len(),
expected_refs.len(),
"Length of actual and expected refs are not equal"
);
for (&expected, actual) in expected_refs.iter().zip(vm_refs) {
assert_eq!(expected, actual);
Some(range.start + self.start..range.end + self.start)
}
}
#[test]
fn test_range_map() {
let map = RangeMap::from_iter([(0, 10..14), (10, 20..24), (20, 30..32)]);
assert_eq!(map.get(0..4), Some(10..14));
assert_eq!(map.get(10..14), Some(20..24));
assert_eq!(map.get(20..22), Some(30..32));
assert_eq!(map.get(0..2), Some(10..12));
assert_eq!(map.get(11..13), Some(21..23));
assert_eq!(map.get(0..10), None);
assert_eq!(map.get(10..20), None);
assert_eq!(map.get(20..30), None);
}
#[test]
fn test_range_map_index() {
let map = RangeMap::from_iter([(0, 10..14), (10, 20..24), (20, 30..32)]);
let idx = RangeSet::from([0..4, 10..14, 20..22]);
assert_eq!(map.index(&idx), Some(map.clone()));
let idx = RangeSet::from(25..30);
assert_eq!(map.index(&idx), None);
let idx = RangeSet::from(15..20);
assert_eq!(map.index(&idx), None);
let idx = RangeSet::from([1..3, 11..12, 13..14, 21..22]);
assert_eq!(
map.index(&idx),
Some(RangeMap::from_iter([
(1, 11..13),
(11, 21..22),
(13, 23..24),
(21, 31..32)
]))
);
}
}

View File

@@ -13,7 +13,7 @@ use rangeset::RangeSet;
use serde::{Deserialize, Serialize};
use serio::{SinkExt, stream::IoStreamExt};
use tlsn_core::{
hash::HashAlgorithm,
hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256},
transcript::{
Direction,
encoding::{
@@ -23,7 +23,7 @@ use tlsn_core::{
},
};
use crate::commit::transcript::TranscriptRefs;
use crate::commit::transcript::{Item, RangeMap, ReferenceMap};
/// Bytes of encoding, per byte.
const ENCODING_SIZE: usize = 128;
@@ -34,145 +34,130 @@ struct Encodings {
recv: Vec<u8>,
}
/// Transfers the encodings using the provided seed and keys.
///
/// The keys must be consistent with the global delta used in the encodings.
pub(crate) async fn transfer<'a>(
/// Transfers encodings for the provided plaintext ranges.
pub(crate) async fn transfer<K: KeyStore>(
ctx: &mut Context,
refs: &TranscriptRefs,
delta: &Delta,
f: impl Fn(Vector<U8>) -> &'a [Key],
store: &K,
sent: &ReferenceMap,
recv: &ReferenceMap,
) -> Result<EncodingCommitment, EncodingError> {
let secret = EncoderSecret::new(rand::rng().random(), delta.as_block().to_bytes());
let secret = EncoderSecret::new(rand::rng().random(), store.delta().as_block().to_bytes());
let encoder = new_encoder(&secret);
let sent_keys: Vec<u8> = refs
.sent()
.iter()
.copied()
.flat_map(&f)
.flat_map(|key| key.as_block().as_bytes())
.copied()
.collect();
let recv_keys: Vec<u8> = refs
.recv()
.iter()
.copied()
.flat_map(&f)
.flat_map(|key| key.as_block().as_bytes())
.copied()
.collect();
// Collects the encodings for the provided plaintext ranges.
fn collect_encodings(
encoder: &impl Encoder,
store: &impl KeyStore,
direction: Direction,
map: &ReferenceMap,
) -> Vec<u8> {
let mut encodings = Vec::with_capacity(map.len() * ENCODING_SIZE);
for (range, chunk) in map.iter() {
let start = encodings.len();
encoder.encode_range(direction, range, &mut encodings);
let keys = store
.get_keys(*chunk)
.expect("keys are present for provided plaintext ranges");
encodings[start..]
.iter_mut()
.zip(keys.iter().flat_map(|key| key.as_block().as_bytes()))
.for_each(|(encoding, key)| {
*encoding ^= *key;
});
}
encodings
}
assert_eq!(sent_keys.len() % ENCODING_SIZE, 0);
assert_eq!(recv_keys.len() % ENCODING_SIZE, 0);
let encodings = Encodings {
sent: collect_encodings(&encoder, store, Direction::Sent, sent),
recv: collect_encodings(&encoder, store, Direction::Received, recv),
};
let mut sent_encoding = Vec::with_capacity(sent_keys.len());
let mut recv_encoding = Vec::with_capacity(recv_keys.len());
encoder.encode_range(
Direction::Sent,
0..sent_keys.len() / ENCODING_SIZE,
&mut sent_encoding,
);
encoder.encode_range(
Direction::Received,
0..recv_keys.len() / ENCODING_SIZE,
&mut recv_encoding,
);
sent_encoding
.iter_mut()
.zip(sent_keys)
.for_each(|(enc, key)| *enc ^= key);
recv_encoding
.iter_mut()
.zip(recv_keys)
.for_each(|(enc, key)| *enc ^= key);
// Set frame limit and add some extra bytes cushion room.
let (sent, recv) = refs.len();
let frame_limit = ENCODING_SIZE * (sent + recv) + ctx.io().limit();
ctx.io_mut()
.with_limit(frame_limit)
.send(Encodings {
sent: sent_encoding,
recv: recv_encoding,
})
.await?;
let frame_limit = ctx
.io()
.limit()
.saturating_add(encodings.sent.len() + encodings.recv.len());
ctx.io_mut().with_limit(frame_limit).send(encodings).await?;
let root = ctx.io_mut().expect_next().await?;
ctx.io_mut().send(secret.clone()).await?;
Ok(EncodingCommitment {
root,
secret: secret.clone(),
})
Ok(EncodingCommitment { root, secret })
}
/// Receives the encodings using the provided MACs.
///
/// The MACs must be consistent with the global delta used in the encodings.
pub(crate) async fn receive<'a>(
/// Receives and commits to the encodings for the provided plaintext ranges.
pub(crate) async fn receive<M: MacStore>(
ctx: &mut Context,
hasher: &(dyn HashAlgorithm + Send + Sync),
refs: &TranscriptRefs,
f: impl Fn(Vector<U8>) -> &'a [Mac],
store: &M,
hash_alg: HashAlgId,
sent: &ReferenceMap,
recv: &ReferenceMap,
idxs: impl IntoIterator<Item = &(Direction, RangeSet<usize>)>,
) -> Result<(EncodingCommitment, EncodingTree), EncodingError> {
// Set frame limit and add some extra bytes cushion room.
let (sent, recv) = refs.len();
let frame_limit = ENCODING_SIZE * (sent + recv) + ctx.io().limit();
let hasher: &(dyn HashAlgorithm + Send + Sync) = match hash_alg {
HashAlgId::SHA256 => &Sha256::default(),
HashAlgId::KECCAK256 => &Keccak256::default(),
HashAlgId::BLAKE3 => &Blake3::default(),
alg => {
return Err(ErrorRepr::UnsupportedHashAlgorithm(alg).into());
}
};
let Encodings { mut sent, mut recv } =
ctx.io_mut().with_limit(frame_limit).expect_next().await?;
let (sent_len, recv_len) = (sent.len(), recv.len());
let frame_limit = ctx
.io()
.limit()
.saturating_add(ENCODING_SIZE * (sent_len + recv_len));
let encodings: Encodings = ctx.io_mut().with_limit(frame_limit).expect_next().await?;
let sent_macs: Vec<u8> = refs
.sent()
.iter()
.copied()
.flat_map(&f)
.flat_map(|mac| mac.as_bytes())
.copied()
.collect();
let recv_macs: Vec<u8> = refs
.recv()
.iter()
.copied()
.flat_map(&f)
.flat_map(|mac| mac.as_bytes())
.copied()
.collect();
assert_eq!(sent_macs.len() % ENCODING_SIZE, 0);
assert_eq!(recv_macs.len() % ENCODING_SIZE, 0);
if sent.len() != sent_macs.len() {
if encodings.sent.len() != sent_len * ENCODING_SIZE {
return Err(ErrorRepr::IncorrectMacCount {
direction: Direction::Sent,
expected: sent_macs.len(),
got: sent.len(),
expected: sent_len,
got: encodings.sent.len() / ENCODING_SIZE,
}
.into());
}
if recv.len() != recv_macs.len() {
if encodings.recv.len() != recv_len * ENCODING_SIZE {
return Err(ErrorRepr::IncorrectMacCount {
direction: Direction::Received,
expected: recv_macs.len(),
got: recv.len(),
expected: recv_len,
got: encodings.recv.len() / ENCODING_SIZE,
}
.into());
}
sent.iter_mut()
.zip(sent_macs)
.for_each(|(enc, mac)| *enc ^= mac);
recv.iter_mut()
.zip(recv_macs)
.for_each(|(enc, mac)| *enc ^= mac);
// Collects a map of plaintext ranges to their encodings.
fn collect_map(
store: &impl MacStore,
mut encodings: Vec<u8>,
map: &ReferenceMap,
) -> RangeMap<EncodingSlice> {
let mut encoding_map = Vec::new();
let mut pos = 0;
for (range, chunk) in map.iter() {
let macs = store
.get_macs(*chunk)
.expect("MACs are present for provided plaintext ranges");
let encoding = &mut encodings[pos..pos + range.len() * ENCODING_SIZE];
encoding
.iter_mut()
.zip(macs.iter().flat_map(|mac| mac.as_bytes()))
.for_each(|(encoding, mac)| {
*encoding ^= *mac;
});
let provider = Provider { sent, recv };
encoding_map.push((range.start, EncodingSlice::from(&(*encoding))));
pos += range.len() * ENCODING_SIZE;
}
RangeMap::new(encoding_map)
}
let provider = Provider {
sent: collect_map(store, encodings.sent, sent),
recv: collect_map(store, encodings.recv, recv),
};
let tree = EncodingTree::new(hasher, idxs, &provider)?;
let root = tree.root();
@@ -185,10 +170,36 @@ pub(crate) async fn receive<'a>(
Ok((commitment, tree))
}
pub(crate) trait KeyStore {
fn delta(&self) -> &Delta;
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]>;
}
impl KeyStore for crate::verifier::Zk {
fn delta(&self) -> &Delta {
crate::verifier::Zk::delta(self)
}
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]> {
self.get_keys(data).ok()
}
}
pub(crate) trait MacStore {
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]>;
}
impl MacStore for crate::prover::Zk {
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]> {
self.get_macs(data).ok()
}
}
#[derive(Debug)]
struct Provider {
sent: Vec<u8>,
recv: Vec<u8>,
sent: RangeMap<EncodingSlice>,
recv: RangeMap<EncodingSlice>,
}
impl EncodingProvider for Provider {
@@ -203,19 +214,39 @@ impl EncodingProvider for Provider {
Direction::Received => &self.recv,
};
let start = range.start * ENCODING_SIZE;
let end = range.end * ENCODING_SIZE;
let encoding = encodings.get(range).ok_or(EncodingProviderError)?;
if end > encodings.len() {
return Err(EncodingProviderError);
}
dest.extend_from_slice(&encodings[start..end]);
dest.extend_from_slice(encoding);
Ok(())
}
}
#[derive(Debug)]
struct EncodingSlice(Vec<u8>);
impl From<&[u8]> for EncodingSlice {
fn from(value: &[u8]) -> Self {
Self(value.to_vec())
}
}
impl Item for EncodingSlice {
type Slice<'a>
= &'a [u8]
where
Self: 'a;
fn length(&self) -> usize {
self.0.len() / ENCODING_SIZE
}
fn slice<'a>(&'a self, range: Range<usize>) -> Option<Self::Slice<'a>> {
self.0
.get(range.start * ENCODING_SIZE..range.end * ENCODING_SIZE)
}
}
/// Encoding protocol error.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
@@ -234,6 +265,8 @@ enum ErrorRepr {
},
#[error("encoding tree error: {0}")]
EncodingTree(EncodingTreeError),
#[error("unsupported hash algorithm: {0}")]
UnsupportedHashAlgorithm(HashAlgId),
}
impl From<std::io::Error> for EncodingError {

View File

@@ -9,7 +9,6 @@ pub mod config;
pub(crate) mod context;
pub(crate) mod encoding;
pub(crate) mod ghash;
pub(crate) mod msg;
pub(crate) mod mux;
pub mod prover;
pub(crate) mod tag;

View File

@@ -1,15 +0,0 @@
//! Message types.
use serde::{Deserialize, Serialize};
use tlsn_core::connection::{HandshakeData, ServerName};
/// Message sent from Prover to Verifier to prove the server identity.
#[derive(Debug, Serialize, Deserialize)]
#[allow(dead_code)]
pub(crate) struct ServerIdentityProof {
/// Server name.
pub name: ServerName,
/// Server identity data.
pub data: HandshakeData,
}

View File

@@ -3,6 +3,7 @@
mod config;
mod error;
mod future;
mod prove;
pub mod state;
pub use config::{ProverConfig, ProverConfigBuilder, TlsConfig, TlsConfigBuilder};
@@ -18,19 +19,7 @@ use mpz_vm_core::prelude::*;
use mpz_zk::ProverConfig as ZkProverConfig;
use webpki::anchor_from_trusted_cert;
use crate::{
Role,
commit::{
commit_records,
hash::prove_hash,
transcript::{TranscriptRefs, decode_transcript},
},
context::build_mt_context,
encoding,
mux::attach_mux,
tag::verify_tags,
zk_aes_ctr::ZkAesCtr,
};
use crate::{Role, context::build_mt_context, mux::attach_mux, tag::verify_tags};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
@@ -39,12 +28,9 @@ use serio::SinkExt;
use std::sync::Arc;
use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client};
use tls_core::msgs::enums::ContentType;
use tlsn_core::{
ProvePayload,
connection::{HandshakeData, ServerName},
hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256},
transcript::{TlsTranscript, Transcript, TranscriptCommitment, TranscriptSecret},
connection::ServerName,
transcript::{TlsTranscript, Transcript},
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
@@ -115,22 +101,6 @@ impl Prover<state::Initialized> {
let mut keys = mpc_tls.alloc()?;
let vm_lock = vm.try_lock().expect("VM is not locked");
translate_keys(&mut keys, &vm_lock)?;
// Allocate for committing to plaintext.
let mut zk_aes_ctr_sent = ZkAesCtr::new(Role::Prover);
zk_aes_ctr_sent.set_key(keys.client_write_key, keys.client_write_iv);
zk_aes_ctr_sent.alloc(
&mut *vm_lock.zk(),
self.config.protocol_config().max_sent_data(),
)?;
let mut zk_aes_ctr_recv = ZkAesCtr::new(Role::Prover);
zk_aes_ctr_recv.set_key(keys.server_write_key, keys.server_write_iv);
zk_aes_ctr_recv.alloc(
&mut *vm_lock.zk(),
self.config.protocol_config().max_recv_data(),
)?;
drop(vm_lock);
debug!("setting up mpc-tls");
@@ -146,8 +116,6 @@ impl Prover<state::Initialized> {
mux_ctrl,
mux_fut,
mpc_tls,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
vm,
},
@@ -173,8 +141,6 @@ impl Prover<state::Setup> {
mux_ctrl,
mut mux_fut,
mpc_tls,
mut zk_aes_ctr_sent,
mut zk_aes_ctr_recv,
keys,
vm,
..
@@ -281,35 +247,6 @@ impl Prover<state::Setup> {
)
.map_err(ProverError::zk)?;
// Prove sent and received plaintext. Prover drops the proof
// output, as they trust themselves.
let (sent_refs, _) = commit_records(
&mut vm,
&mut zk_aes_ctr_sent,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(ProverError::zk)?;
let (recv_refs, _) = commit_records(
&mut vm,
&mut zk_aes_ctr_recv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(|e| {
if e.is_insufficient() {
ProverError::zk(format!("{e}. Attempted to prove more received data than was configured, increase `max_recv_data` in the config."))
} else {
ProverError::zk(e)
}
})?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
.await?;
@@ -317,7 +254,6 @@ impl Prover<state::Setup> {
let transcript = tls_transcript
.to_transcript()
.expect("transcript is complete");
let transcript_refs = TranscriptRefs::new(sent_refs, recv_refs);
Ok(Prover {
config: self.config,
@@ -327,9 +263,9 @@ impl Prover<state::Setup> {
mux_fut,
ctx,
vm,
keys,
tls_transcript,
transcript,
transcript_refs,
},
})
}
@@ -368,117 +304,24 @@ impl Prover<state::Committed> {
mux_fut,
ctx,
vm,
keys,
tls_transcript,
transcript,
transcript_refs,
..
} = &mut self.state;
let mut output = ProverOutput {
transcript_commitments: Vec::new(),
transcript_secrets: Vec::new(),
};
let partial_transcript = if let Some((sent, recv)) = config.reveal() {
decode_transcript(vm, sent, recv, transcript_refs).map_err(ProverError::zk)?;
Some(transcript.to_partial(sent.clone(), recv.clone()))
} else {
None
};
let payload = ProvePayload {
handshake: config.server_identity().then(|| {
(
self.config.server_name().clone(),
HandshakeData {
certs: tls_transcript
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: tls_transcript
.server_signature()
.expect("server signature is present")
.clone(),
binding: tls_transcript.certificate_binding().clone(),
},
)
}),
transcript: partial_transcript,
transcript_commit: config.transcript_commit().map(|config| config.to_request()),
};
// Send payload.
mux_fut
.poll_with(ctx.io_mut().send(payload).map_err(ProverError::from))
let output = mux_fut
.poll_with(prove::prove(
ctx,
vm,
keys,
self.config.server_name(),
transcript,
tls_transcript,
config,
))
.await?;
let mut hash_commitments = None;
if let Some(commit_config) = config.transcript_commit() {
if commit_config.has_encoding() {
let hasher: &(dyn HashAlgorithm + Send + Sync) =
match *commit_config.encoding_hash_alg() {
HashAlgId::SHA256 => &Sha256::default(),
HashAlgId::KECCAK256 => &Keccak256::default(),
HashAlgId::BLAKE3 => &Blake3::default(),
alg => {
return Err(ProverError::config(format!(
"unsupported hash algorithm for encoding commitment: {alg}"
)));
}
};
let (commitment, tree) = mux_fut
.poll_with(
encoding::receive(
ctx,
hasher,
transcript_refs,
|plaintext| vm.get_macs(plaintext).expect("reference is valid"),
commit_config.iter_encoding(),
)
.map_err(ProverError::commit),
)
.await?;
output
.transcript_commitments
.push(TranscriptCommitment::Encoding(commitment));
output
.transcript_secrets
.push(TranscriptSecret::Encoding(tree));
}
if commit_config.has_hash() {
hash_commitments = Some(
prove_hash(
vm,
transcript_refs,
commit_config
.iter_hash()
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg)),
)
.map_err(ProverError::commit)?,
);
}
}
mux_fut
.poll_with(vm.execute_all(ctx).map_err(ProverError::zk))
.await?;
if let Some((hash_fut, hash_secrets)) = hash_commitments {
let hash_commitments = hash_fut.try_recv().map_err(ProverError::commit)?;
for (commitment, secret) in hash_commitments.into_iter().zip(hash_secrets) {
output
.transcript_commitments
.push(TranscriptCommitment::Hash(commitment));
output
.transcript_secrets
.push(TranscriptSecret::Hash(secret));
}
}
Ok(output)
}

View File

@@ -2,7 +2,7 @@ use std::{error::Error, fmt};
use mpc_tls::MpcTlsError;
use crate::{encoding::EncodingError, zk_aes_ctr::ZkAesCtrError};
use crate::encoding::EncodingError;
/// Error for [`Prover`](crate::Prover).
#[derive(Debug, thiserror::Error)]
@@ -110,12 +110,6 @@ impl From<MpcTlsError> for ProverError {
}
}
impl From<ZkAesCtrError> for ProverError {
fn from(e: ZkAesCtrError) -> Self {
Self::new(ErrorKind::Zk, e)
}
}
impl From<EncodingError> for ProverError {
fn from(e: EncodingError) -> Self {
Self::new(ErrorKind::Commit, e)

View File

@@ -0,0 +1,198 @@
use mpc_tls::SessionKeys;
use mpz_common::Context;
use mpz_memory_core::binary::Binary;
use mpz_vm_core::Vm;
use rangeset::{RangeSet, UnionMut};
use serio::SinkExt;
use tlsn_core::{
ProveConfig, ProveRequest, ProverOutput,
connection::{HandshakeData, ServerName},
transcript::{
ContentType, Direction, TlsTranscript, Transcript, TranscriptCommitment, TranscriptSecret,
},
};
use crate::{
commit::{auth::prove_plaintext, hash::prove_hash, transcript::TranscriptRefs},
encoding::{self, MacStore},
prover::ProverError,
zk_aes_ctr::ZkAesCtr,
};
pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
ctx: &mut Context,
vm: &mut T,
keys: &SessionKeys,
server_name: &ServerName,
transcript: &Transcript,
tls_transcript: &TlsTranscript,
config: &ProveConfig,
) -> Result<ProverOutput, ProverError> {
let mut output = ProverOutput {
transcript_commitments: Vec::default(),
transcript_secrets: Vec::default(),
};
let request = ProveRequest {
handshake: config.server_identity().then(|| {
(
server_name.clone(),
HandshakeData {
certs: tls_transcript
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: tls_transcript
.server_signature()
.expect("server signature is present")
.clone(),
binding: tls_transcript.certificate_binding().clone(),
},
)
}),
transcript: config
.reveal()
.map(|(sent, recv)| transcript.to_partial(sent.clone(), recv.clone())),
transcript_commit: config.transcript_commit().map(|config| config.to_request()),
};
ctx.io_mut()
.send(request)
.await
.map_err(ProverError::from)?;
let mut auth_sent_ranges = RangeSet::default();
let mut auth_recv_ranges = RangeSet::default();
let (reveal_sent, reveal_recv) = config.reveal().cloned().unwrap_or_default();
auth_sent_ranges.union_mut(&reveal_sent);
auth_recv_ranges.union_mut(&reveal_recv);
if let Some(commit_config) = config.transcript_commit() {
commit_config
.iter_hash()
.for_each(|((direction, idx), _)| match direction {
Direction::Sent => auth_sent_ranges.union_mut(idx),
Direction::Received => auth_recv_ranges.union_mut(idx),
});
commit_config
.iter_encoding()
.for_each(|(direction, idx)| match direction {
Direction::Sent => auth_sent_ranges.union_mut(idx),
Direction::Received => auth_recv_ranges.union_mut(idx),
});
}
let mut zk_aes_sent = ZkAesCtr::new(
keys.client_write_key,
keys.client_write_iv,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
);
let mut zk_aes_recv = ZkAesCtr::new(
keys.server_write_key,
keys.server_write_iv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
);
let sent_refs = prove_plaintext(
vm,
&mut zk_aes_sent,
transcript.sent(),
&auth_sent_ranges,
&reveal_sent,
)
.map_err(ProverError::commit)?;
let recv_refs = prove_plaintext(
vm,
&mut zk_aes_recv,
transcript.received(),
&auth_recv_ranges,
&reveal_recv,
)
.map_err(ProverError::commit)?;
let transcript_refs = TranscriptRefs {
sent: sent_refs,
recv: recv_refs,
};
let hash_commitments = if let Some(commit_config) = config.transcript_commit()
&& commit_config.has_hash()
{
Some(
prove_hash(
vm,
&transcript_refs,
commit_config
.iter_hash()
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg)),
)
.map_err(ProverError::commit)?,
)
} else {
None
};
vm.execute_all(ctx).await.map_err(ProverError::zk)?;
if let Some(commit_config) = config.transcript_commit()
&& commit_config.has_encoding()
{
let mut sent_ranges = RangeSet::default();
let mut recv_ranges = RangeSet::default();
for (dir, idx) in commit_config.iter_encoding() {
match dir {
Direction::Sent => sent_ranges.union_mut(idx),
Direction::Received => recv_ranges.union_mut(idx),
}
}
let sent_map = transcript_refs
.sent
.index(&sent_ranges)
.expect("indices are valid");
let recv_map = transcript_refs
.recv
.index(&recv_ranges)
.expect("indices are valid");
let (commitment, tree) = encoding::receive(
ctx,
vm,
*commit_config.encoding_hash_alg(),
&sent_map,
&recv_map,
commit_config.iter_encoding(),
)
.await?;
output
.transcript_commitments
.push(TranscriptCommitment::Encoding(commitment));
output
.transcript_secrets
.push(TranscriptSecret::Encoding(tree));
}
if let Some((hash_fut, hash_secrets)) = hash_commitments {
let hash_commitments = hash_fut.try_recv().map_err(ProverError::commit)?;
for (commitment, secret) in hash_commitments.into_iter().zip(hash_secrets) {
output
.transcript_commitments
.push(TranscriptCommitment::Hash(commitment));
output
.transcript_secrets
.push(TranscriptSecret::Hash(secret));
}
}
Ok(output)
}

View File

@@ -9,10 +9,8 @@ use tlsn_deap::Deap;
use tokio::sync::Mutex;
use crate::{
commit::transcript::TranscriptRefs,
mux::{MuxControl, MuxFuture},
prover::{Mpc, Zk},
zk_aes_ctr::ZkAesCtr,
};
/// Entry state
@@ -25,8 +23,6 @@ pub struct Setup {
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsLeader,
pub(crate) zk_aes_ctr_sent: ZkAesCtr,
pub(crate) zk_aes_ctr_recv: ZkAesCtr,
pub(crate) keys: SessionKeys,
pub(crate) vm: Arc<Mutex<Deap<Mpc, Zk>>>,
}
@@ -39,9 +35,9 @@ pub struct Committed {
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
pub(crate) vm: Zk,
pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript: Transcript,
pub(crate) transcript_refs: TranscriptRefs,
}
opaque_debug::implement!(Committed);

View File

@@ -3,6 +3,7 @@
pub(crate) mod config;
mod error;
pub mod state;
mod verify;
use std::sync::Arc;
@@ -14,18 +15,7 @@ pub use tlsn_core::{
};
use crate::{
Role,
commit::{
commit_records,
hash::verify_hash,
transcript::{TranscriptRefs, decode_transcript, verify_transcript},
},
config::ProtocolConfig,
context::build_mt_context,
encoding,
mux::attach_mux,
tag::verify_tags,
zk_aes_ctr::ZkAesCtr,
Role, config::ProtocolConfig, context::build_mt_context, mux::attach_mux, tag::verify_tags,
};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{MpcTlsFollower, SessionKeys};
@@ -35,11 +25,9 @@ use mpz_garble_core::Delta;
use mpz_vm_core::prelude::*;
use mpz_zk::VerifierConfig as ZkVerifierConfig;
use serio::stream::IoStreamExt;
use tls_core::msgs::enums::ContentType;
use tlsn_core::{
ProvePayload,
connection::{ConnectionInfo, ServerName},
transcript::{TlsTranscript, TranscriptCommitment},
transcript::TlsTranscript,
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
@@ -114,23 +102,12 @@ impl Verifier<state::Initialized> {
})
.await?;
let delta = Delta::random(&mut rand::rng());
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, &protocol_config, delta, ctx);
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, &protocol_config, ctx);
// Allocate resources for MPC-TLS in the VM.
let mut keys = mpc_tls.alloc()?;
let vm_lock = vm.try_lock().expect("VM is not locked");
translate_keys(&mut keys, &vm_lock)?;
// Allocate for committing to plaintext.
let mut zk_aes_ctr_sent = ZkAesCtr::new(Role::Verifier);
zk_aes_ctr_sent.set_key(keys.client_write_key, keys.client_write_iv);
zk_aes_ctr_sent.alloc(&mut *vm_lock.zk(), protocol_config.max_sent_data())?;
let mut zk_aes_ctr_recv = ZkAesCtr::new(Role::Verifier);
zk_aes_ctr_recv.set_key(keys.server_write_key, keys.server_write_iv);
zk_aes_ctr_recv.alloc(&mut *vm_lock.zk(), protocol_config.max_recv_data())?;
drop(vm_lock);
debug!("setting up mpc-tls");
@@ -145,10 +122,7 @@ impl Verifier<state::Initialized> {
state: state::Setup {
mux_ctrl,
mux_fut,
delta,
mpc_tls,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
vm,
},
@@ -186,10 +160,7 @@ impl Verifier<state::Setup> {
let state::Setup {
mux_ctrl,
mut mux_fut,
delta,
mpc_tls,
mut zk_aes_ctr_sent,
mut zk_aes_ctr_recv,
vm,
keys,
} = self.state;
@@ -230,27 +201,6 @@ impl Verifier<state::Setup> {
)
.map_err(VerifierError::zk)?;
// Prepare for the prover to prove received plaintext.
let (sent_refs, sent_proof) = commit_records(
&mut vm,
&mut zk_aes_ctr_sent,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(VerifierError::zk)?;
let (recv_refs, recv_proof) = commit_records(
&mut vm,
&mut zk_aes_ctr_recv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(VerifierError::zk)?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(VerifierError::zk))
.await?;
@@ -260,23 +210,16 @@ impl Verifier<state::Setup> {
// authenticated from the verifier's perspective.
tag_proof.verify().map_err(VerifierError::zk)?;
// Verify the plaintext proofs.
sent_proof.verify().map_err(VerifierError::zk)?;
recv_proof.verify().map_err(VerifierError::zk)?;
let transcript_refs = TranscriptRefs::new(sent_refs, recv_refs);
Ok(Verifier {
config: self.config,
span: self.span,
state: state::Committed {
mux_ctrl,
mux_fut,
delta,
ctx,
vm,
keys,
tls_transcript,
transcript_refs,
},
})
}
@@ -301,130 +244,34 @@ impl Verifier<state::Committed> {
let state::Committed {
mux_fut,
ctx,
delta,
vm,
keys,
tls_transcript,
transcript_refs,
..
} = &mut self.state;
let ProvePayload {
handshake,
transcript,
transcript_commit,
} = mux_fut
.poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from))
.await?;
let verifier = if let Some(root_store) = self.config.root_store() {
let cert_verifier = if let Some(root_store) = self.config.root_store() {
ServerCertVerifier::new(root_store).map_err(VerifierError::config)?
} else {
ServerCertVerifier::mozilla()
};
let server_name = if let Some((name, cert_data)) = handshake {
cert_data
.verify(
&verifier,
tls_transcript.time(),
tls_transcript.server_ephemeral_key(),
&name,
)
.map_err(VerifierError::verify)?;
Some(name)
} else {
None
};
if let Some(partial_transcript) = &transcript {
let sent_len = tls_transcript
.sent()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
let recv_len = tls_transcript
.recv()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
// Check ranges.
if partial_transcript.len_sent() != sent_len
|| partial_transcript.len_received() != recv_len
{
return Err(VerifierError::verify(
"prover sent transcript with incorrect length",
));
}
decode_transcript(
vm,
partial_transcript.sent_authed(),
partial_transcript.received_authed(),
transcript_refs,
)
.map_err(VerifierError::zk)?;
}
let mut transcript_commitments = Vec::new();
let mut hash_commitments = None;
if let Some(commit_config) = transcript_commit {
if commit_config.encoding() {
let commitment = mux_fut
.poll_with(encoding::transfer(
ctx,
transcript_refs,
delta,
|plaintext| vm.get_keys(plaintext).expect("reference is valid"),
))
.await?;
transcript_commitments.push(TranscriptCommitment::Encoding(commitment));
}
if commit_config.has_hash() {
hash_commitments = Some(
verify_hash(vm, transcript_refs, commit_config.iter_hash().cloned())
.map_err(VerifierError::verify)?,
);
}
}
mux_fut
.poll_with(vm.execute_all(ctx).map_err(VerifierError::zk))
let request = mux_fut
.poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from))
.await?;
// Verify revealed data.
if let Some(partial_transcript) = &transcript {
verify_transcript(vm, partial_transcript, transcript_refs)
.map_err(VerifierError::verify)?;
}
let output = mux_fut
.poll_with(verify::verify(
ctx,
vm,
keys,
&cert_verifier,
tls_transcript,
request,
))
.await?;
if let Some(hash_commitments) = hash_commitments {
for commitment in hash_commitments.try_recv().map_err(VerifierError::verify)? {
transcript_commitments.push(TranscriptCommitment::Hash(commitment));
}
}
Ok(VerifierOutput {
server_name,
transcript,
transcript_commitments,
})
Ok(output)
}
/// Closes the connection with the prover.
@@ -447,11 +294,11 @@ impl Verifier<state::Committed> {
fn build_mpc_tls(
config: &VerifierConfig,
protocol_config: &ProtocolConfig,
delta: Delta,
ctx: Context,
) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsFollower) {
let mut rng = rand::rng();
let delta = Delta::random(&mut rng);
let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
let rcot_send = mpz_ot::kos::Sender::new(

View File

@@ -1,4 +1,4 @@
use crate::{encoding::EncodingError, zk_aes_ctr::ZkAesCtrError};
use crate::encoding::EncodingError;
use mpc_tls::MpcTlsError;
use std::{error::Error, fmt};
@@ -110,12 +110,6 @@ impl From<MpcTlsError> for VerifierError {
}
}
impl From<ZkAesCtrError> for VerifierError {
fn from(e: ZkAesCtrError) -> Self {
Self::new(ErrorKind::Zk, e)
}
}
impl From<EncodingError> for VerifierError {
fn from(e: EncodingError) -> Self {
Self::new(ErrorKind::Commit, e)

View File

@@ -2,14 +2,9 @@
use std::sync::Arc;
use crate::{
commit::transcript::TranscriptRefs,
mux::{MuxControl, MuxFuture},
zk_aes_ctr::ZkAesCtr,
};
use crate::mux::{MuxControl, MuxFuture};
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
use mpz_memory_core::correlated::Delta;
use tlsn_core::transcript::TlsTranscript;
use tlsn_deap::Deap;
use tokio::sync::Mutex;
@@ -28,10 +23,7 @@ opaque_debug::implement!(Initialized);
pub struct Setup {
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) delta: Delta,
pub(crate) mpc_tls: MpcTlsFollower,
pub(crate) zk_aes_ctr_sent: ZkAesCtr,
pub(crate) zk_aes_ctr_recv: ZkAesCtr,
pub(crate) keys: SessionKeys,
pub(crate) vm: Arc<Mutex<Deap<Mpc, Zk>>>,
}
@@ -40,11 +32,10 @@ pub struct Setup {
pub struct Committed {
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) delta: Delta,
pub(crate) ctx: Context,
pub(crate) vm: Zk,
pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript_refs: TranscriptRefs,
}
opaque_debug::implement!(Committed);

View File

@@ -0,0 +1,183 @@
use mpc_tls::SessionKeys;
use mpz_common::Context;
use mpz_memory_core::binary::Binary;
use mpz_vm_core::Vm;
use rangeset::{RangeSet, UnionMut};
use tlsn_core::{
ProveRequest, VerifierOutput,
transcript::{
ContentType, Direction, PartialTranscript, Record, TlsTranscript, TranscriptCommitment,
},
webpki::ServerCertVerifier,
};
use crate::{
commit::{auth::verify_plaintext, hash::verify_hash, transcript::TranscriptRefs},
encoding::{self, KeyStore},
verifier::VerifierError,
zk_aes_ctr::ZkAesCtr,
};
pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
ctx: &mut Context,
vm: &mut T,
keys: &SessionKeys,
cert_verifier: &ServerCertVerifier,
tls_transcript: &TlsTranscript,
request: ProveRequest,
) -> Result<VerifierOutput, VerifierError> {
let ProveRequest {
handshake,
transcript,
transcript_commit,
} = request;
let ciphertext_sent = collect_ciphertext(tls_transcript.sent());
let ciphertext_recv = collect_ciphertext(tls_transcript.recv());
let has_reveal = transcript.is_some();
let transcript = if let Some(transcript) = transcript {
if transcript.len_sent() != ciphertext_sent.len()
|| transcript.len_received() != ciphertext_recv.len()
{
return Err(VerifierError::verify(
"prover sent transcript with incorrect length",
));
}
transcript
} else {
PartialTranscript::new(ciphertext_sent.len(), ciphertext_recv.len())
};
let server_name = if let Some((name, cert_data)) = handshake {
cert_data
.verify(
cert_verifier,
tls_transcript.time(),
tls_transcript.server_ephemeral_key(),
&name,
)
.map_err(VerifierError::verify)?;
Some(name)
} else {
None
};
let mut auth_sent_ranges = RangeSet::default();
let mut auth_recv_ranges = RangeSet::default();
auth_sent_ranges.union_mut(transcript.sent_authed());
auth_recv_ranges.union_mut(transcript.received_authed());
if let Some(commit_config) = transcript_commit.as_ref() {
commit_config
.iter_hash()
.for_each(|(direction, idx, _)| match direction {
Direction::Sent => auth_sent_ranges.union_mut(idx),
Direction::Received => auth_recv_ranges.union_mut(idx),
});
if let Some((sent, recv)) = commit_config.encoding() {
auth_sent_ranges.union_mut(sent);
auth_recv_ranges.union_mut(recv);
}
}
let mut zk_aes_sent = ZkAesCtr::new(
keys.client_write_key,
keys.client_write_iv,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
);
let mut zk_aes_recv = ZkAesCtr::new(
keys.server_write_key,
keys.server_write_iv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
);
let (sent_refs, sent_proof) = verify_plaintext(
vm,
&mut zk_aes_sent,
transcript.sent_unsafe(),
&ciphertext_sent,
&auth_sent_ranges,
transcript.sent_authed(),
)
.map_err(VerifierError::zk)?;
let (recv_refs, recv_proof) = verify_plaintext(
vm,
&mut zk_aes_recv,
transcript.received_unsafe(),
&ciphertext_recv,
&auth_recv_ranges,
transcript.received_authed(),
)
.map_err(VerifierError::zk)?;
let transcript_refs = TranscriptRefs {
sent: sent_refs,
recv: recv_refs,
};
let mut transcript_commitments = Vec::new();
let mut hash_commitments = None;
if let Some(commit_config) = transcript_commit.as_ref()
&& commit_config.has_hash()
{
hash_commitments = Some(
verify_hash(vm, &transcript_refs, commit_config.iter_hash().cloned())
.map_err(VerifierError::verify)?,
);
}
vm.execute_all(ctx).await.map_err(VerifierError::zk)?;
sent_proof.verify().map_err(VerifierError::verify)?;
recv_proof.verify().map_err(VerifierError::verify)?;
if let Some(commit_config) = transcript_commit
&& let Some((sent, recv)) = commit_config.encoding()
{
let sent_map = transcript_refs
.sent
.index(sent)
.expect("ranges were authenticated");
let recv_map = transcript_refs
.recv
.index(recv)
.expect("ranges were authenticated");
let commitment = encoding::transfer(ctx, vm, &sent_map, &recv_map).await?;
transcript_commitments.push(TranscriptCommitment::Encoding(commitment));
}
if let Some(hash_commitments) = hash_commitments {
for commitment in hash_commitments.try_recv().map_err(VerifierError::verify)? {
transcript_commitments.push(TranscriptCommitment::Hash(commitment));
}
}
Ok(VerifierOutput {
server_name,
transcript: has_reveal.then_some(transcript),
transcript_commitments,
})
}
fn collect_ciphertext<'a>(records: impl IntoIterator<Item = &'a Record>) -> Vec<u8> {
let mut ciphertext = Vec::new();
records
.into_iter()
.filter(|record| record.typ == ContentType::ApplicationData)
.for_each(|record| {
ciphertext.extend_from_slice(&record.ciphertext);
});
ciphertext
}

View File

@@ -1,181 +1,226 @@
//! Zero-knowledge AES-CTR encryption.
use std::{ops::Range, sync::Arc};
use cipher::{
Cipher, CipherError, Keystream,
aes::{Aes128, AesError},
};
use mpz_circuits::circuits::{AES128, xor};
use mpz_memory_core::{
Array, Vector,
Array, MemoryExt, Vector, ViewExt,
binary::{Binary, U8},
};
use mpz_vm_core::{Vm, prelude::*};
use mpz_vm_core::{Call, CallableExt, Vm};
use rangeset::RangeSet;
use tlsn_core::transcript::Record;
use crate::Role;
type Nonce = Array<U8, 8>;
type Ctr = Array<U8, 4>;
type Block = Array<U8, 16>;
const START_CTR: u32 = 2;
use crate::commit::transcript::ReferenceMap;
/// ZK AES-CTR encryption.
#[derive(Debug)]
pub(crate) struct ZkAesCtr {
role: Role,
aes: Aes128,
state: State,
key: Array<U8, 16>,
iv: Array<U8, 4>,
records: Vec<(usize, RecordState)>,
total_len: usize,
}
impl ZkAesCtr {
/// Creates a new ZK AES-CTR instance.
pub(crate) fn new(role: Role) -> Self {
/// Creates a new instance.
pub(crate) fn new<'record>(
key: Array<U8, 16>,
iv: Array<U8, 4>,
records: impl IntoIterator<Item = &'record Record>,
) -> Self {
let mut pos = 0;
let mut record_state = Vec::new();
for record in records {
record_state.push((
pos,
RecordState {
explicit_nonce: Some(record.explicit_nonce.clone()),
explicit_nonce_ref: None,
range: pos..pos + record.ciphertext.len(),
},
));
pos += record.ciphertext.len();
}
Self {
role,
aes: Aes128::default(),
state: State::Init,
key,
iv,
records: record_state,
total_len: pos,
}
}
/// Returns the role.
pub(crate) fn role(&self) -> &Role {
&self.role
}
/// Allocates `len` bytes for encryption.
pub(crate) fn alloc(
/// Allocates the plaintext for the provided ranges.
///
/// Returns a reference to the plaintext and the ciphertext.
pub(crate) fn alloc_plaintext(
&mut self,
vm: &mut dyn Vm<Binary>,
len: usize,
) -> Result<(), ZkAesCtrError> {
let State::Init = self.state.take() else {
Err(ErrorRepr::State {
reason: "must be in init state to allocate",
})?
};
ranges: &RangeSet<usize>,
) -> Result<(ReferenceMap, ReferenceMap), ZkAesCtrError> {
let len = ranges.len();
// Round up to the nearest block size.
let len = 16 * len.div_ceil(16);
let input = vm.alloc_vec::<U8>(len).map_err(ZkAesCtrError::vm)?;
let keystream = self.aes.alloc_keystream(vm, len)?;
match self.role {
Role::Prover => vm.mark_private(input).map_err(ZkAesCtrError::vm)?,
Role::Verifier => vm.mark_blind(input).map_err(ZkAesCtrError::vm)?,
if len > self.total_len {
return Err(ZkAesCtrError(ErrorRepr::TranscriptBounds {
len,
max: self.total_len,
}));
}
self.state = State::Ready { input, keystream };
let plaintext = vm.alloc_vec::<U8>(len).map_err(ZkAesCtrError::vm)?;
let keystream = self.alloc_keystream(vm, ranges)?;
Ok(())
let mut builder = Call::builder(Arc::new(xor(len * 8))).arg(plaintext);
for slice in keystream {
builder = builder.arg(slice);
}
let call = builder.build().expect("call should be valid");
let ciphertext: Vector<U8> = vm.call(call).map_err(ZkAesCtrError::vm)?;
let mut pos = 0;
let plaintext = ReferenceMap::from_iter(ranges.iter_ranges().map(move |range| {
let chunk = plaintext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
}));
let mut pos = 0;
let ciphertext = ReferenceMap::from_iter(ranges.iter_ranges().map(move |range| {
let chunk = ciphertext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
}));
Ok((plaintext, ciphertext))
}
/// Sets the key and IV for the cipher.
pub(crate) fn set_key(&mut self, key: Array<U8, 16>, iv: Array<U8, 4>) {
self.aes.set_key(key);
self.aes.set_iv(iv);
}
/// Proves the encryption of `len` bytes.
///
/// Here we only assign certain values in the VM but the actual proving
/// happens later when the plaintext is assigned and the VM is executed.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `explicit_nonce` - Explicit nonce.
/// * `len` - Length of the plaintext in bytes.
///
/// # Returns
///
/// A VM reference to the plaintext and the ciphertext.
pub(crate) fn encrypt(
fn alloc_keystream(
&mut self,
vm: &mut dyn Vm<Binary>,
explicit_nonce: Vec<u8>,
len: usize,
) -> Result<(Vector<U8>, Vector<U8>), ZkAesCtrError> {
let State::Ready { input, keystream } = &mut self.state else {
Err(ErrorRepr::State {
reason: "must be in ready state to encrypt",
})?
};
ranges: &RangeSet<usize>,
) -> Result<Vec<Vector<U8>>, ZkAesCtrError> {
let mut keystream = Vec::new();
let explicit_nonce: [u8; 8] =
explicit_nonce
.try_into()
.map_err(|explicit_nonce: Vec<_>| ErrorRepr::ExplicitNonceLength {
expected: 8,
actual: explicit_nonce.len(),
})?;
let mut range_iter = ranges.iter_ranges();
let mut current_range = range_iter.next();
for (pos, record) in self.records.iter_mut() {
let pos = *pos;
let mut current_block = None;
loop {
let Some(range) = current_range.take().or_else(|| range_iter.next()) else {
return Ok(keystream);
};
let block_count = len.div_ceil(16);
let padded_len = block_count * 16;
let padding_len = padded_len - len;
if range.start >= record.range.end {
current_range = Some(range);
break;
}
if padded_len > input.len() {
Err(ErrorRepr::InsufficientPreprocessing {
expected: padded_len,
actual: input.len(),
})?
}
const BLOCK_SIZE: usize = 16;
let block_num = (range.start - pos) / BLOCK_SIZE;
let block = if let Some((current_block_num, block)) = current_block.take()
&& current_block_num == block_num
{
block
} else {
let block = record.alloc_block(vm, self.key, self.iv, block_num)?;
let mut input = input.split_off(input.len() - padded_len);
let keystream = keystream.consume(padded_len)?;
let mut output = keystream.apply(vm, input)?;
current_block = Some((block_num, block));
// Assign counter block inputs.
let mut ctr = START_CTR..;
keystream.assign(vm, explicit_nonce, move || {
ctr.next().expect("range is unbounded").to_be_bytes()
})?;
block
};
// Assign zeroes to the padding.
if padding_len > 0 {
let padding = input.split_off(input.len() - padding_len);
// To simplify the impl, we don't mark the padding as public, that's why only
// the prover assigns it.
if let Role::Prover = self.role {
vm.assign(padding, vec![0; padding_len])
.map_err(ZkAesCtrError::vm)?;
let start = (range.start - pos) % BLOCK_SIZE;
let end = (start + range.len()).min(BLOCK_SIZE);
let len = end - start;
keystream.push(block.get(start..end).expect("range is checked"));
// If the range is larger than a block, process the tail.
if range.len() > BLOCK_SIZE {
current_range = Some(range.start + len..range.end);
}
}
vm.commit(padding).map_err(ZkAesCtrError::vm)?;
output.truncate(len);
}
Ok((input, output))
unreachable!("plaintext length was checked");
}
}
enum State {
Init,
Ready {
input: Vector<U8>,
keystream: Keystream<Nonce, Ctr, Block>,
},
Error,
#[derive(Debug)]
struct RecordState {
explicit_nonce: Option<Vec<u8>>,
range: Range<usize>,
explicit_nonce_ref: Option<Vector<U8>>,
}
impl State {
fn take(&mut self) -> Self {
std::mem::replace(self, State::Error)
}
}
impl RecordState {
fn alloc_explicit_nonce(
&mut self,
vm: &mut dyn Vm<Binary>,
) -> Result<Vector<U8>, ZkAesCtrError> {
if let Some(explicit_nonce) = self.explicit_nonce_ref {
Ok(explicit_nonce)
} else {
const EXPLICIT_NONCE_LEN: usize = 8;
let explicit_nonce_ref = vm
.alloc_vec::<U8>(EXPLICIT_NONCE_LEN)
.map_err(ZkAesCtrError::vm)?;
vm.mark_public(explicit_nonce_ref)
.map_err(ZkAesCtrError::vm)?;
vm.assign(
explicit_nonce_ref,
self.explicit_nonce
.take()
.expect("explicit nonce only set once"),
)
.map_err(ZkAesCtrError::vm)?;
vm.commit(explicit_nonce_ref).map_err(ZkAesCtrError::vm)?;
impl std::fmt::Debug for State {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
State::Init => write!(f, "Init"),
State::Ready { .. } => write!(f, "Ready"),
State::Error => write!(f, "Error"),
self.explicit_nonce_ref = Some(explicit_nonce_ref);
Ok(explicit_nonce_ref)
}
}
fn alloc_block(
&mut self,
vm: &mut dyn Vm<Binary>,
key: Array<U8, 16>,
iv: Array<U8, 4>,
block: usize,
) -> Result<Vector<U8>, ZkAesCtrError> {
let explicit_nonce = self.alloc_explicit_nonce(vm)?;
let ctr: Array<U8, 4> = vm.alloc().map_err(ZkAesCtrError::vm)?;
vm.mark_public(ctr).map_err(ZkAesCtrError::vm)?;
const START_CTR: u32 = 2;
vm.assign(ctr, (START_CTR + block as u32).to_be_bytes())
.map_err(ZkAesCtrError::vm)?;
vm.commit(ctr).map_err(ZkAesCtrError::vm)?;
let block: Array<U8, 16> = vm
.call(
Call::builder(AES128.clone())
.arg(key)
.arg(iv)
.arg(explicit_nonce)
.arg(ctr)
.build()
.expect("call should be valid"),
)
.map_err(ZkAesCtrError::vm)?;
Ok(Vector::from(block))
}
}
/// Error for [`ZkAesCtr`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ZkAesCtrError(#[from] ErrorRepr);
pub(crate) struct ZkAesCtrError(#[from] ErrorRepr);
impl ZkAesCtrError {
fn vm<E>(err: E) -> Self
@@ -184,35 +229,13 @@ impl ZkAesCtrError {
{
Self(ErrorRepr::Vm(err.into()))
}
pub fn is_insufficient(&self) -> bool {
matches!(self.0, ErrorRepr::InsufficientPreprocessing { .. })
}
}
#[derive(Debug, thiserror::Error)]
#[error("zk aes error")]
enum ErrorRepr {
#[error("invalid state: {reason}")]
State { reason: &'static str },
#[error("cipher error: {0}")]
Cipher(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("vm error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("invalid explicit nonce length: expected {expected}, got {actual}")]
ExplicitNonceLength { expected: usize, actual: usize },
#[error("insufficient preprocessing: expected {expected}, got {actual}")]
InsufficientPreprocessing { expected: usize, actual: usize },
}
impl From<AesError> for ZkAesCtrError {
fn from(err: AesError) -> Self {
Self(ErrorRepr::Cipher(Box::new(err)))
}
}
impl From<CipherError> for ZkAesCtrError {
fn from(err: CipherError) -> Self {
Self(ErrorRepr::Cipher(Box::new(err)))
}
#[error("transcript bounds exceeded: {len} > {max}")]
TranscriptBounds { len: usize, max: usize },
}

View File

@@ -1,9 +1,14 @@
use futures::{AsyncReadExt, AsyncWriteExt};
use rangeset::RangeSet;
use tlsn::{
config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore},
connection::ServerName,
hash::{HashAlgId, HashProvider},
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
transcript::{TranscriptCommitConfig, TranscriptCommitment},
transcript::{
Direction, TranscriptCommitConfig, TranscriptCommitment, TranscriptCommitmentKind,
TranscriptSecret,
},
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
};
use tlsn_server_fixture::bind;
@@ -86,9 +91,25 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(verifier_soc
let mut builder = TranscriptCommitConfig::builder(prover.transcript());
// Commit to everything
builder.commit_sent(&(0..sent_tx_len)).unwrap();
builder.commit_recv(&(0..recv_tx_len)).unwrap();
for kind in [
TranscriptCommitmentKind::Encoding,
TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256,
},
] {
builder
.commit_with_kind(&(0..sent_tx_len), Direction::Sent, kind)
.unwrap();
builder
.commit_with_kind(&(0..recv_tx_len), Direction::Received, kind)
.unwrap();
builder
.commit_with_kind(&(1..sent_tx_len - 1), Direction::Sent, kind)
.unwrap();
builder
.commit_with_kind(&(1..recv_tx_len - 1), Direction::Received, kind)
.unwrap();
}
let transcript_commit = builder.build().unwrap();
@@ -102,9 +123,52 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(verifier_soc
builder.transcript_commit(transcript_commit);
let config = builder.build().unwrap();
prover.prove(&config).await.unwrap();
let transcript = prover.transcript().clone();
let output = prover.prove(&config).await.unwrap();
prover.close().await.unwrap();
let encoding_tree = output
.transcript_secrets
.iter()
.find_map(|secret| {
if let TranscriptSecret::Encoding(tree) = secret {
Some(tree)
} else {
None
}
})
.unwrap();
let encoding_commitment = output
.transcript_commitments
.iter()
.find_map(|commitment| {
if let TranscriptCommitment::Encoding(commitment) = commitment {
Some(commitment)
} else {
None
}
})
.unwrap();
let prove_sent = RangeSet::from(1..sent_tx_len - 1);
let prove_recv = RangeSet::from(1..recv_tx_len - 1);
let idxs = [
(Direction::Sent, prove_sent.clone()),
(Direction::Received, prove_recv.clone()),
];
let proof = encoding_tree.proof(idxs.iter()).unwrap();
let (auth_sent, auth_recv) = proof
.verify_with_provider(
&HashProvider::default(),
encoding_commitment,
transcript.sent(),
transcript.received(),
)
.unwrap();
assert_eq!(auth_sent, prove_sent);
assert_eq!(auth_recv, prove_recv);
}
#[instrument(skip(socket))]
@@ -125,14 +189,21 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(soc
.unwrap(),
);
let mut verifier = verifier
.setup(socket.compat())
.await
.unwrap()
.run()
.await
.unwrap();
let VerifierOutput {
server_name,
transcript,
transcript_commitments,
} = verifier
.verify(socket.compat(), &VerifyConfig::default())
.await
.unwrap();
} = verifier.verify(&VerifyConfig::default()).await.unwrap();
verifier.close().await.unwrap();
let transcript = transcript.unwrap();