mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-01-07 22:03:58 -05:00
feat(tlsn): partial plaintext auth (#1006)
Co-authored-by: th4s <th4s@metavoid.xyz>
This commit is contained in:
631
Cargo.lock
generated
631
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
26
Cargo.toml
26
Cargo.toml
@@ -66,19 +66,19 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
|
||||
tlsn-wasm = { path = "crates/wasm" }
|
||||
tlsn = { path = "crates/tlsn" }
|
||||
|
||||
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
|
||||
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
|
||||
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
|
||||
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
|
||||
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
|
||||
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
|
||||
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
|
||||
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
|
||||
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
|
||||
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
|
||||
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
|
||||
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
|
||||
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "8dce54e" }
|
||||
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
|
||||
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
|
||||
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
|
||||
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
|
||||
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
|
||||
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
|
||||
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
|
||||
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
|
||||
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
|
||||
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
|
||||
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
|
||||
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
|
||||
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "f30e07c" }
|
||||
|
||||
rangeset = { version = "0.2" }
|
||||
serio = { version = "0.2" }
|
||||
|
||||
@@ -190,10 +190,10 @@ pub struct VerifyConfigBuilderError(#[from] VerifyConfigBuilderErrorRepr);
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum VerifyConfigBuilderErrorRepr {}
|
||||
|
||||
/// Payload sent to the verifier.
|
||||
/// Request to prove statements about the connection.
|
||||
#[doc(hidden)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ProvePayload {
|
||||
pub struct ProveRequest {
|
||||
/// Handshake data.
|
||||
pub handshake: Option<(ServerName, HandshakeData)>,
|
||||
/// Transcript data.
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
use std::{collections::HashSet, fmt};
|
||||
|
||||
use rangeset::ToRangeSet;
|
||||
use rangeset::{ToRangeSet, UnionMut};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
@@ -114,7 +114,19 @@ impl TranscriptCommitConfig {
|
||||
/// Returns a request for the transcript commitments.
|
||||
pub fn to_request(&self) -> TranscriptCommitRequest {
|
||||
TranscriptCommitRequest {
|
||||
encoding: self.has_encoding,
|
||||
encoding: self.has_encoding.then(|| {
|
||||
let mut sent = RangeSet::default();
|
||||
let mut recv = RangeSet::default();
|
||||
|
||||
for (dir, idx) in self.iter_encoding() {
|
||||
match dir {
|
||||
Direction::Sent => sent.union_mut(idx),
|
||||
Direction::Received => recv.union_mut(idx),
|
||||
}
|
||||
}
|
||||
|
||||
(sent, recv)
|
||||
}),
|
||||
hash: self
|
||||
.iter_hash()
|
||||
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg))
|
||||
@@ -289,14 +301,14 @@ impl fmt::Display for TranscriptCommitConfigBuilderError {
|
||||
/// Request to compute transcript commitments.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TranscriptCommitRequest {
|
||||
encoding: bool,
|
||||
encoding: Option<(RangeSet<usize>, RangeSet<usize>)>,
|
||||
hash: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
|
||||
}
|
||||
|
||||
impl TranscriptCommitRequest {
|
||||
/// Returns `true` if an encoding commitment is requested.
|
||||
pub fn encoding(&self) -> bool {
|
||||
self.encoding
|
||||
pub fn has_encoding(&self) -> bool {
|
||||
self.encoding.is_some()
|
||||
}
|
||||
|
||||
/// Returns `true` if a hash commitment is requested.
|
||||
@@ -308,6 +320,11 @@ impl TranscriptCommitRequest {
|
||||
pub fn iter_hash(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>, HashAlgId)> {
|
||||
self.hash.iter()
|
||||
}
|
||||
|
||||
/// Returns the ranges of the encoding commitments.
|
||||
pub fn encoding(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
|
||||
self.encoding.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -31,6 +31,7 @@ web-spawn = { workspace = true, optional = true }
|
||||
|
||||
mpz-common = { workspace = true }
|
||||
mpz-core = { workspace = true }
|
||||
mpz-circuits = { workspace = true }
|
||||
mpz-garble = { workspace = true }
|
||||
mpz-garble-core = { workspace = true }
|
||||
mpz-hash = { workspace = true }
|
||||
|
||||
@@ -1,116 +1,5 @@
|
||||
//! Plaintext commitment and proof of encryption.
|
||||
|
||||
pub(crate) mod auth;
|
||||
pub(crate) mod hash;
|
||||
pub(crate) mod transcript;
|
||||
|
||||
use mpz_core::bitvec::BitVec;
|
||||
use mpz_memory_core::{
|
||||
DecodeFutureTyped, Vector,
|
||||
binary::{Binary, U8},
|
||||
};
|
||||
use mpz_vm_core::{Vm, prelude::*};
|
||||
use tlsn_core::transcript::Record;
|
||||
|
||||
use crate::{
|
||||
Role,
|
||||
zk_aes_ctr::{ZkAesCtr, ZkAesCtrError},
|
||||
};
|
||||
|
||||
/// Commits the plaintext of the provided records, returning a proof of
|
||||
/// encryption.
|
||||
///
|
||||
/// Writes the plaintext VM reference to the provided records.
|
||||
pub(crate) fn commit_records<'record>(
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
aes: &mut ZkAesCtr,
|
||||
records: impl IntoIterator<Item = &'record Record>,
|
||||
) -> Result<(Vec<Vector<U8>>, RecordProof), RecordProofError> {
|
||||
let mut plaintexts = Vec::new();
|
||||
let mut ciphertexts = Vec::new();
|
||||
for record in records {
|
||||
let (plaintext_ref, ciphertext_ref) = aes
|
||||
.encrypt(vm, record.explicit_nonce.clone(), record.ciphertext.len())
|
||||
.map_err(ErrorRepr::Aes)?;
|
||||
|
||||
if let Role::Prover = aes.role() {
|
||||
let Some(plaintext) = record.plaintext.clone() else {
|
||||
return Err(ErrorRepr::MissingPlaintext.into());
|
||||
};
|
||||
|
||||
vm.assign(plaintext_ref, plaintext)
|
||||
.map_err(RecordProofError::vm)?;
|
||||
}
|
||||
vm.commit(plaintext_ref).map_err(RecordProofError::vm)?;
|
||||
|
||||
let ciphertext = vm.decode(ciphertext_ref).map_err(RecordProofError::vm)?;
|
||||
|
||||
plaintexts.push(plaintext_ref);
|
||||
ciphertexts.push((ciphertext, record.ciphertext.clone()));
|
||||
}
|
||||
|
||||
Ok((plaintexts, RecordProof { ciphertexts }))
|
||||
}
|
||||
|
||||
/// Proof of encryption.
|
||||
#[derive(Debug)]
|
||||
#[must_use]
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub(crate) struct RecordProof {
|
||||
ciphertexts: Vec<(DecodeFutureTyped<BitVec, Vec<u8>>, Vec<u8>)>,
|
||||
}
|
||||
|
||||
impl RecordProof {
|
||||
/// Verifies the proof.
|
||||
pub(crate) fn verify(self) -> Result<(), RecordProofError> {
|
||||
let Self { ciphertexts } = self;
|
||||
|
||||
for (mut ciphertext, expected) in ciphertexts {
|
||||
let ciphertext = ciphertext
|
||||
.try_recv()
|
||||
.map_err(RecordProofError::vm)?
|
||||
.ok_or_else(|| ErrorRepr::NotDecoded)?;
|
||||
|
||||
if ciphertext != expected {
|
||||
return Err(ErrorRepr::InvalidCiphertext.into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Error for [`RecordProof`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub(crate) struct RecordProofError(#[from] ErrorRepr);
|
||||
|
||||
impl RecordProofError {
|
||||
fn vm<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Vm(err.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn is_insufficient(&self) -> bool {
|
||||
match &self.0 {
|
||||
ErrorRepr::Aes(err) => err.is_insufficient(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("record proof error: {0}")]
|
||||
enum ErrorRepr {
|
||||
#[error("VM error: {0}")]
|
||||
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
#[error("zk aes error: {0}")]
|
||||
Aes(ZkAesCtrError),
|
||||
#[error("plaintext is missing")]
|
||||
MissingPlaintext,
|
||||
#[error("ciphertext was not decoded")]
|
||||
NotDecoded,
|
||||
#[error("ciphertext does not match expected")]
|
||||
InvalidCiphertext,
|
||||
}
|
||||
|
||||
166
crates/tlsn/src/commit/auth.rs
Normal file
166
crates/tlsn/src/commit/auth.rs
Normal file
@@ -0,0 +1,166 @@
|
||||
use mpz_core::bitvec::BitVec;
|
||||
use mpz_memory_core::{DecodeFutureTyped, MemoryExt, ViewExt, binary::Binary};
|
||||
use mpz_vm_core::Vm;
|
||||
use rangeset::{Difference, RangeSet, Subset};
|
||||
|
||||
use crate::{
|
||||
commit::transcript::ReferenceMap,
|
||||
zk_aes_ctr::{ZkAesCtr, ZkAesCtrError},
|
||||
};
|
||||
|
||||
pub(crate) fn prove_plaintext(
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
zk_aes: &mut ZkAesCtr,
|
||||
plaintext: &[u8],
|
||||
ranges: &RangeSet<usize>,
|
||||
public: &RangeSet<usize>,
|
||||
) -> Result<ReferenceMap, PlaintextAuthError> {
|
||||
assert!(public.is_subset(ranges), "public is not a subset of ranges");
|
||||
|
||||
if ranges.is_empty() {
|
||||
return Ok(ReferenceMap::default());
|
||||
}
|
||||
|
||||
let (plaintext_map, ciphertext_map) = zk_aes
|
||||
.alloc_plaintext(vm, ranges)
|
||||
.map_err(ErrorRepr::ZkAesCtr)?;
|
||||
|
||||
for (range, chunk) in plaintext_map
|
||||
.index(&ranges.difference(public))
|
||||
.expect("map contains all ranges")
|
||||
.iter()
|
||||
{
|
||||
vm.mark_private(*chunk).map_err(PlaintextAuthError::vm)?;
|
||||
vm.assign(*chunk, plaintext[range].to_vec())
|
||||
.map_err(PlaintextAuthError::vm)?;
|
||||
vm.commit(*chunk).map_err(PlaintextAuthError::vm)?;
|
||||
}
|
||||
|
||||
for (range, chunk) in plaintext_map
|
||||
.index(public)
|
||||
.expect("map contains all ranges")
|
||||
.iter()
|
||||
{
|
||||
vm.mark_public(*chunk).map_err(PlaintextAuthError::vm)?;
|
||||
vm.assign(*chunk, plaintext[range].to_vec())
|
||||
.map_err(PlaintextAuthError::vm)?;
|
||||
vm.commit(*chunk).map_err(PlaintextAuthError::vm)?;
|
||||
}
|
||||
|
||||
for (_, chunk) in ciphertext_map.iter() {
|
||||
drop(vm.decode(*chunk).map_err(PlaintextAuthError::vm)?);
|
||||
}
|
||||
|
||||
Ok(plaintext_map)
|
||||
}
|
||||
|
||||
pub(crate) fn verify_plaintext(
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
zk_aes: &mut ZkAesCtr,
|
||||
plaintext: &[u8],
|
||||
ciphertext: &[u8],
|
||||
ranges: &RangeSet<usize>,
|
||||
public: &RangeSet<usize>,
|
||||
) -> Result<(ReferenceMap, PlaintextProof), PlaintextAuthError> {
|
||||
assert!(public.is_subset(ranges), "public is not a subset of ranges");
|
||||
|
||||
if ranges.is_empty() {
|
||||
return Ok((
|
||||
ReferenceMap::default(),
|
||||
PlaintextProof {
|
||||
ciphertexts: vec![],
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
let (plaintext_map, ciphertext_map) = zk_aes
|
||||
.alloc_plaintext(vm, ranges)
|
||||
.map_err(ErrorRepr::ZkAesCtr)?;
|
||||
|
||||
for (_, chunk) in plaintext_map
|
||||
.index(&ranges.difference(public))
|
||||
.expect("map contains all ranges")
|
||||
.iter()
|
||||
{
|
||||
vm.mark_blind(*chunk).map_err(PlaintextAuthError::vm)?;
|
||||
vm.commit(*chunk).map_err(PlaintextAuthError::vm)?;
|
||||
}
|
||||
|
||||
for (range, chunk) in plaintext_map
|
||||
.index(public)
|
||||
.expect("map contains all ranges")
|
||||
.iter()
|
||||
{
|
||||
vm.mark_public(*chunk).map_err(PlaintextAuthError::vm)?;
|
||||
vm.assign(*chunk, plaintext[range].to_vec())
|
||||
.map_err(PlaintextAuthError::vm)?;
|
||||
vm.commit(*chunk).map_err(PlaintextAuthError::vm)?;
|
||||
}
|
||||
|
||||
let mut ciphertexts = Vec::new();
|
||||
for (range, chunk) in ciphertext_map
|
||||
.index(ranges)
|
||||
.expect("map contains all ranges")
|
||||
.iter()
|
||||
{
|
||||
ciphertexts.push((
|
||||
ciphertext[range].to_vec(),
|
||||
vm.decode(*chunk).map_err(PlaintextAuthError::vm)?,
|
||||
));
|
||||
}
|
||||
|
||||
Ok((plaintext_map, PlaintextProof { ciphertexts }))
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("plaintext authentication error: {0}")]
|
||||
pub(crate) struct PlaintextAuthError(#[from] ErrorRepr);
|
||||
|
||||
impl PlaintextAuthError {
|
||||
fn vm<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Vm(err.into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum ErrorRepr {
|
||||
#[error("vm error: {0}")]
|
||||
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
#[error("zk aes ctr error: {0}")]
|
||||
ZkAesCtr(ZkAesCtrError),
|
||||
#[error("missing decoding")]
|
||||
MissingDecoding,
|
||||
#[error("invalid ciphertext")]
|
||||
InvalidCiphertext,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub(crate) struct PlaintextProof {
|
||||
// (expected, actual)
|
||||
#[allow(clippy::type_complexity)]
|
||||
ciphertexts: Vec<(Vec<u8>, DecodeFutureTyped<BitVec, Vec<u8>>)>,
|
||||
}
|
||||
|
||||
impl PlaintextProof {
|
||||
pub(crate) fn verify(self) -> Result<(), PlaintextAuthError> {
|
||||
let Self {
|
||||
ciphertexts: ciphertext,
|
||||
} = self;
|
||||
|
||||
for (expected, mut actual) in ciphertext {
|
||||
let actual = actual
|
||||
.try_recv()
|
||||
.map_err(PlaintextAuthError::vm)?
|
||||
.ok_or(PlaintextAuthError(ErrorRepr::MissingDecoding))?;
|
||||
|
||||
if actual != expected {
|
||||
return Err(PlaintextAuthError(ErrorRepr::InvalidCiphertext));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -149,9 +149,15 @@ fn hash_commit_inner(
|
||||
hasher
|
||||
};
|
||||
|
||||
for plaintext in refs.get(direction, &idx).expect("plaintext refs are valid") {
|
||||
hasher.update(&plaintext);
|
||||
let refs = match direction {
|
||||
Direction::Sent => &refs.sent,
|
||||
Direction::Received => &refs.recv,
|
||||
};
|
||||
|
||||
for range in idx.iter_ranges() {
|
||||
hasher.update(&refs.get(range).expect("plaintext refs are valid"));
|
||||
}
|
||||
|
||||
hasher.update(&blinder);
|
||||
hasher.finalize(vm).map_err(HashCommitError::hasher)?
|
||||
}
|
||||
@@ -164,9 +170,14 @@ fn hash_commit_inner(
|
||||
hasher
|
||||
};
|
||||
|
||||
for plaintext in refs.get(direction, &idx).expect("plaintext refs are valid") {
|
||||
let refs = match direction {
|
||||
Direction::Sent => &refs.sent,
|
||||
Direction::Received => &refs.recv,
|
||||
};
|
||||
|
||||
for range in idx.iter_ranges() {
|
||||
hasher
|
||||
.update(vm, &plaintext)
|
||||
.update(vm, &refs.get(range).expect("plaintext refs are valid"))
|
||||
.map_err(HashCommitError::hasher)?;
|
||||
}
|
||||
hasher
|
||||
|
||||
@@ -1,211 +1,205 @@
|
||||
use mpz_memory_core::{
|
||||
MemoryExt, Vector,
|
||||
binary::{Binary, U8},
|
||||
};
|
||||
use mpz_vm_core::{Vm, VmError};
|
||||
use rangeset::{Intersection, RangeSet};
|
||||
use tlsn_core::transcript::{Direction, PartialTranscript};
|
||||
use std::ops::Range;
|
||||
|
||||
use mpz_memory_core::{Vector, binary::U8};
|
||||
use rangeset::RangeSet;
|
||||
|
||||
pub(crate) type ReferenceMap = RangeMap<Vector<U8>>;
|
||||
|
||||
/// References to the application plaintext in the transcript.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub(crate) struct TranscriptRefs {
|
||||
sent: Vec<Vector<U8>>,
|
||||
recv: Vec<Vector<U8>>,
|
||||
pub(crate) sent: ReferenceMap,
|
||||
pub(crate) recv: ReferenceMap,
|
||||
}
|
||||
|
||||
impl TranscriptRefs {
|
||||
pub(crate) fn new(sent: Vec<Vector<U8>>, recv: Vec<Vector<U8>>) -> Self {
|
||||
Self { sent, recv }
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub(crate) struct RangeMap<T> {
|
||||
map: Vec<(usize, T)>,
|
||||
}
|
||||
|
||||
impl<T> Default for RangeMap<T>
|
||||
where
|
||||
T: Item,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self { map: Vec::new() }
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the sent plaintext references.
|
||||
pub(crate) fn sent(&self) -> &[Vector<U8>] {
|
||||
&self.sent
|
||||
}
|
||||
impl<T> RangeMap<T>
|
||||
where
|
||||
T: Item,
|
||||
{
|
||||
pub(crate) fn new(map: Vec<(usize, T)>) -> Self {
|
||||
let mut pos = 0;
|
||||
for (idx, item) in &map {
|
||||
assert!(
|
||||
*idx >= pos,
|
||||
"items must be sorted by index and non-overlapping"
|
||||
);
|
||||
|
||||
/// Returns the received plaintext references.
|
||||
pub(crate) fn recv(&self) -> &[Vector<U8>] {
|
||||
&self.recv
|
||||
}
|
||||
|
||||
/// Returns the transcript lengths.
|
||||
pub(crate) fn len(&self) -> (usize, usize) {
|
||||
let sent = self.sent.iter().map(|v| v.len()).sum();
|
||||
let recv = self.recv.iter().map(|v| v.len()).sum();
|
||||
|
||||
(sent, recv)
|
||||
}
|
||||
|
||||
/// Returns VM references for the given direction and index, otherwise
|
||||
/// `None` if the index is out of bounds.
|
||||
pub(crate) fn get(
|
||||
&self,
|
||||
direction: Direction,
|
||||
idx: &RangeSet<usize>,
|
||||
) -> Option<Vec<Vector<U8>>> {
|
||||
if idx.is_empty() {
|
||||
return Some(Vec::new());
|
||||
pos = *idx + item.length();
|
||||
}
|
||||
|
||||
let refs = match direction {
|
||||
Direction::Sent => &self.sent,
|
||||
Direction::Received => &self.recv,
|
||||
Self { map }
|
||||
}
|
||||
|
||||
/// Returns the length of the map.
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
self.map.iter().map(|(_, item)| item.length()).sum()
|
||||
}
|
||||
|
||||
pub(crate) fn iter(&self) -> impl Iterator<Item = (Range<usize>, &T)> {
|
||||
self.map
|
||||
.iter()
|
||||
.map(|(idx, item)| (*idx..*idx + item.length(), item))
|
||||
}
|
||||
|
||||
pub(crate) fn get(&self, range: Range<usize>) -> Option<T::Slice<'_>> {
|
||||
if range.start >= range.end {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Find the item with the greatest start index <= range.start
|
||||
let pos = match self.map.binary_search_by(|(idx, _)| idx.cmp(&range.start)) {
|
||||
Ok(i) => i,
|
||||
Err(0) => return None,
|
||||
Err(i) => i - 1,
|
||||
};
|
||||
|
||||
// Computes the transcript range for each reference.
|
||||
let mut start = 0;
|
||||
let mut slice_iter = refs.iter().map(move |slice| {
|
||||
let out = (slice, start..start + slice.len());
|
||||
start += slice.len();
|
||||
out
|
||||
});
|
||||
let (base, item) = &self.map[pos];
|
||||
|
||||
let mut slices = Vec::new();
|
||||
let (mut slice, mut slice_range) = slice_iter.next()?;
|
||||
for range in idx.iter_ranges() {
|
||||
loop {
|
||||
if let Some(intersection) = slice_range.intersection(&range) {
|
||||
let start = intersection.start - slice_range.start;
|
||||
let end = intersection.end - slice_range.start;
|
||||
slices.push(slice.get(start..end).expect("range should be in bounds"));
|
||||
}
|
||||
item.slice(range.start - *base..range.end - *base)
|
||||
}
|
||||
|
||||
// Proceed to next range if the current slice extends beyond. Otherwise, proceed
|
||||
// to the next slice.
|
||||
if range.end <= slice_range.end {
|
||||
break;
|
||||
} else {
|
||||
(slice, slice_range) = slice_iter.next()?;
|
||||
}
|
||||
pub(crate) fn index(&self, idx: &RangeSet<usize>) -> Option<Self> {
|
||||
let mut map = Vec::new();
|
||||
for idx in idx.iter_ranges() {
|
||||
let pos = match self.map.binary_search_by(|(base, _)| base.cmp(&idx.start)) {
|
||||
Ok(i) => i,
|
||||
Err(0) => return None,
|
||||
Err(i) => i - 1,
|
||||
};
|
||||
|
||||
let (base, item) = self.map.get(pos)?;
|
||||
if idx.start < *base || idx.end > *base + item.length() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let start = idx.start - *base;
|
||||
let end = start + idx.len();
|
||||
|
||||
map.push((
|
||||
idx.start,
|
||||
item.slice(start..end)
|
||||
.expect("slice length is checked")
|
||||
.into(),
|
||||
));
|
||||
}
|
||||
|
||||
Some(slices)
|
||||
Some(Self { map })
|
||||
}
|
||||
}
|
||||
|
||||
/// Decodes the transcript.
|
||||
pub(crate) fn decode_transcript(
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
sent: &RangeSet<usize>,
|
||||
recv: &RangeSet<usize>,
|
||||
refs: &TranscriptRefs,
|
||||
) -> Result<(), VmError> {
|
||||
let sent_refs = refs.get(Direction::Sent, sent).expect("index is in bounds");
|
||||
let recv_refs = refs
|
||||
.get(Direction::Received, recv)
|
||||
.expect("index is in bounds");
|
||||
impl<T> FromIterator<(usize, T)> for RangeMap<T>
|
||||
where
|
||||
T: Item,
|
||||
{
|
||||
fn from_iter<I: IntoIterator<Item = (usize, T)>>(items: I) -> Self {
|
||||
let mut pos = 0;
|
||||
let mut map = Vec::new();
|
||||
for (idx, item) in items {
|
||||
assert!(
|
||||
idx >= pos,
|
||||
"items must be sorted by index and non-overlapping"
|
||||
);
|
||||
|
||||
for slice in sent_refs.into_iter().chain(recv_refs) {
|
||||
// Drop the future, we don't need it.
|
||||
drop(vm.decode(slice)?);
|
||||
pos = idx + item.length();
|
||||
map.push((idx, item));
|
||||
}
|
||||
|
||||
Self { map }
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Verifies a partial transcript.
|
||||
pub(crate) fn verify_transcript(
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
transcript: &PartialTranscript,
|
||||
refs: &TranscriptRefs,
|
||||
) -> Result<(), InconsistentTranscript> {
|
||||
let sent_refs = refs
|
||||
.get(Direction::Sent, transcript.sent_authed())
|
||||
.expect("index is in bounds");
|
||||
let recv_refs = refs
|
||||
.get(Direction::Received, transcript.received_authed())
|
||||
.expect("index is in bounds");
|
||||
pub(crate) trait Item: Sized {
|
||||
type Slice<'a>: Into<Self>
|
||||
where
|
||||
Self: 'a;
|
||||
|
||||
let mut authenticated_data = Vec::new();
|
||||
for data in sent_refs.into_iter().chain(recv_refs) {
|
||||
let plaintext = vm
|
||||
.get(data)
|
||||
.expect("reference is valid")
|
||||
.expect("plaintext is decoded");
|
||||
authenticated_data.extend_from_slice(&plaintext);
|
||||
}
|
||||
fn length(&self) -> usize;
|
||||
|
||||
let mut purported_data = Vec::with_capacity(authenticated_data.len());
|
||||
for range in transcript.sent_authed().iter_ranges() {
|
||||
purported_data.extend_from_slice(&transcript.sent_unsafe()[range]);
|
||||
}
|
||||
|
||||
for range in transcript.received_authed().iter_ranges() {
|
||||
purported_data.extend_from_slice(&transcript.received_unsafe()[range]);
|
||||
}
|
||||
|
||||
if purported_data != authenticated_data {
|
||||
return Err(InconsistentTranscript {});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
fn slice<'a>(&'a self, range: Range<usize>) -> Option<Self::Slice<'a>>;
|
||||
}
|
||||
|
||||
/// Error for [`verify_transcript`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("inconsistent transcript")]
|
||||
pub(crate) struct InconsistentTranscript {}
|
||||
impl Item for Vector<U8> {
|
||||
type Slice<'a> = Vector<U8>;
|
||||
|
||||
fn length(&self) -> usize {
|
||||
self.len()
|
||||
}
|
||||
|
||||
fn slice<'a>(&'a self, range: Range<usize>) -> Option<Self::Slice<'a>> {
|
||||
self.get(range)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::TranscriptRefs;
|
||||
use mpz_memory_core::{FromRaw, Slice, Vector, binary::U8};
|
||||
use rangeset::RangeSet;
|
||||
use std::ops::Range;
|
||||
use tlsn_core::transcript::Direction;
|
||||
use super::*;
|
||||
|
||||
// TRANSCRIPT_REFS:
|
||||
//
|
||||
// 48..96 -> 6 slots
|
||||
// 112..176 -> 8 slots
|
||||
// 240..288 -> 6 slots
|
||||
// 352..392 -> 5 slots
|
||||
// 440..480 -> 5 slots
|
||||
const TRANSCRIPT_REFS: &[Range<usize>] = &[48..96, 112..176, 240..288, 352..392, 440..480];
|
||||
impl Item for Range<usize> {
|
||||
type Slice<'a> = Range<usize>;
|
||||
|
||||
const IDXS: &[Range<usize>] = &[0..4, 5..10, 14..16, 16..28];
|
||||
fn length(&self) -> usize {
|
||||
self.end - self.start
|
||||
}
|
||||
|
||||
// 1. Take slots 0..4, 4 slots -> 48..80 (4)
|
||||
// 2. Take slots 5..10, 5 slots -> 88..96 (1) + 112..144 (4)
|
||||
// 3. Take slots 14..16, 2 slots -> 240..256 (2)
|
||||
// 4. Take slots 16..28, 12 slots -> 256..288 (4) + 352..392 (5) + 440..464 (3)
|
||||
//
|
||||
// 5. Merge slots 240..256 and 256..288 => 240..288 and get EXPECTED_REFS
|
||||
const EXPECTED_REFS: &[Range<usize>] =
|
||||
&[48..80, 88..96, 112..144, 240..288, 352..392, 440..464];
|
||||
fn slice(&self, range: Range<usize>) -> Option<Self> {
|
||||
if range.end > self.end - self.start {
|
||||
return None;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transcript_refs_get() {
|
||||
let transcript_refs: Vec<Vector<U8>> = TRANSCRIPT_REFS
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|range| Vector::from_raw(Slice::from_range_unchecked(range)))
|
||||
.collect();
|
||||
|
||||
let transcript_refs = TranscriptRefs {
|
||||
sent: transcript_refs.clone(),
|
||||
recv: transcript_refs,
|
||||
};
|
||||
|
||||
let vm_refs = transcript_refs
|
||||
.get(Direction::Sent, &RangeSet::from(IDXS))
|
||||
.unwrap();
|
||||
|
||||
let expected_refs: Vec<Vector<U8>> = EXPECTED_REFS
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|range| Vector::from_raw(Slice::from_range_unchecked(range)))
|
||||
.collect();
|
||||
|
||||
assert_eq!(
|
||||
vm_refs.len(),
|
||||
expected_refs.len(),
|
||||
"Length of actual and expected refs are not equal"
|
||||
);
|
||||
|
||||
for (&expected, actual) in expected_refs.iter().zip(vm_refs) {
|
||||
assert_eq!(expected, actual);
|
||||
Some(range.start + self.start..range.end + self.start)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_range_map() {
|
||||
let map = RangeMap::from_iter([(0, 10..14), (10, 20..24), (20, 30..32)]);
|
||||
|
||||
assert_eq!(map.get(0..4), Some(10..14));
|
||||
assert_eq!(map.get(10..14), Some(20..24));
|
||||
assert_eq!(map.get(20..22), Some(30..32));
|
||||
assert_eq!(map.get(0..2), Some(10..12));
|
||||
assert_eq!(map.get(11..13), Some(21..23));
|
||||
assert_eq!(map.get(0..10), None);
|
||||
assert_eq!(map.get(10..20), None);
|
||||
assert_eq!(map.get(20..30), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_range_map_index() {
|
||||
let map = RangeMap::from_iter([(0, 10..14), (10, 20..24), (20, 30..32)]);
|
||||
|
||||
let idx = RangeSet::from([0..4, 10..14, 20..22]);
|
||||
assert_eq!(map.index(&idx), Some(map.clone()));
|
||||
|
||||
let idx = RangeSet::from(25..30);
|
||||
assert_eq!(map.index(&idx), None);
|
||||
|
||||
let idx = RangeSet::from(15..20);
|
||||
assert_eq!(map.index(&idx), None);
|
||||
|
||||
let idx = RangeSet::from([1..3, 11..12, 13..14, 21..22]);
|
||||
assert_eq!(
|
||||
map.index(&idx),
|
||||
Some(RangeMap::from_iter([
|
||||
(1, 11..13),
|
||||
(11, 21..22),
|
||||
(13, 23..24),
|
||||
(21, 31..32)
|
||||
]))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ use rangeset::RangeSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serio::{SinkExt, stream::IoStreamExt};
|
||||
use tlsn_core::{
|
||||
hash::HashAlgorithm,
|
||||
hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256},
|
||||
transcript::{
|
||||
Direction,
|
||||
encoding::{
|
||||
@@ -23,7 +23,7 @@ use tlsn_core::{
|
||||
},
|
||||
};
|
||||
|
||||
use crate::commit::transcript::TranscriptRefs;
|
||||
use crate::commit::transcript::{Item, RangeMap, ReferenceMap};
|
||||
|
||||
/// Bytes of encoding, per byte.
|
||||
const ENCODING_SIZE: usize = 128;
|
||||
@@ -34,145 +34,130 @@ struct Encodings {
|
||||
recv: Vec<u8>,
|
||||
}
|
||||
|
||||
/// Transfers the encodings using the provided seed and keys.
|
||||
///
|
||||
/// The keys must be consistent with the global delta used in the encodings.
|
||||
pub(crate) async fn transfer<'a>(
|
||||
/// Transfers encodings for the provided plaintext ranges.
|
||||
pub(crate) async fn transfer<K: KeyStore>(
|
||||
ctx: &mut Context,
|
||||
refs: &TranscriptRefs,
|
||||
delta: &Delta,
|
||||
f: impl Fn(Vector<U8>) -> &'a [Key],
|
||||
store: &K,
|
||||
sent: &ReferenceMap,
|
||||
recv: &ReferenceMap,
|
||||
) -> Result<EncodingCommitment, EncodingError> {
|
||||
let secret = EncoderSecret::new(rand::rng().random(), delta.as_block().to_bytes());
|
||||
let secret = EncoderSecret::new(rand::rng().random(), store.delta().as_block().to_bytes());
|
||||
let encoder = new_encoder(&secret);
|
||||
|
||||
let sent_keys: Vec<u8> = refs
|
||||
.sent()
|
||||
.iter()
|
||||
.copied()
|
||||
.flat_map(&f)
|
||||
.flat_map(|key| key.as_block().as_bytes())
|
||||
.copied()
|
||||
.collect();
|
||||
let recv_keys: Vec<u8> = refs
|
||||
.recv()
|
||||
.iter()
|
||||
.copied()
|
||||
.flat_map(&f)
|
||||
.flat_map(|key| key.as_block().as_bytes())
|
||||
.copied()
|
||||
.collect();
|
||||
// Collects the encodings for the provided plaintext ranges.
|
||||
fn collect_encodings(
|
||||
encoder: &impl Encoder,
|
||||
store: &impl KeyStore,
|
||||
direction: Direction,
|
||||
map: &ReferenceMap,
|
||||
) -> Vec<u8> {
|
||||
let mut encodings = Vec::with_capacity(map.len() * ENCODING_SIZE);
|
||||
for (range, chunk) in map.iter() {
|
||||
let start = encodings.len();
|
||||
encoder.encode_range(direction, range, &mut encodings);
|
||||
let keys = store
|
||||
.get_keys(*chunk)
|
||||
.expect("keys are present for provided plaintext ranges");
|
||||
encodings[start..]
|
||||
.iter_mut()
|
||||
.zip(keys.iter().flat_map(|key| key.as_block().as_bytes()))
|
||||
.for_each(|(encoding, key)| {
|
||||
*encoding ^= *key;
|
||||
});
|
||||
}
|
||||
encodings
|
||||
}
|
||||
|
||||
assert_eq!(sent_keys.len() % ENCODING_SIZE, 0);
|
||||
assert_eq!(recv_keys.len() % ENCODING_SIZE, 0);
|
||||
let encodings = Encodings {
|
||||
sent: collect_encodings(&encoder, store, Direction::Sent, sent),
|
||||
recv: collect_encodings(&encoder, store, Direction::Received, recv),
|
||||
};
|
||||
|
||||
let mut sent_encoding = Vec::with_capacity(sent_keys.len());
|
||||
let mut recv_encoding = Vec::with_capacity(recv_keys.len());
|
||||
|
||||
encoder.encode_range(
|
||||
Direction::Sent,
|
||||
0..sent_keys.len() / ENCODING_SIZE,
|
||||
&mut sent_encoding,
|
||||
);
|
||||
encoder.encode_range(
|
||||
Direction::Received,
|
||||
0..recv_keys.len() / ENCODING_SIZE,
|
||||
&mut recv_encoding,
|
||||
);
|
||||
|
||||
sent_encoding
|
||||
.iter_mut()
|
||||
.zip(sent_keys)
|
||||
.for_each(|(enc, key)| *enc ^= key);
|
||||
recv_encoding
|
||||
.iter_mut()
|
||||
.zip(recv_keys)
|
||||
.for_each(|(enc, key)| *enc ^= key);
|
||||
|
||||
// Set frame limit and add some extra bytes cushion room.
|
||||
let (sent, recv) = refs.len();
|
||||
let frame_limit = ENCODING_SIZE * (sent + recv) + ctx.io().limit();
|
||||
|
||||
ctx.io_mut()
|
||||
.with_limit(frame_limit)
|
||||
.send(Encodings {
|
||||
sent: sent_encoding,
|
||||
recv: recv_encoding,
|
||||
})
|
||||
.await?;
|
||||
let frame_limit = ctx
|
||||
.io()
|
||||
.limit()
|
||||
.saturating_add(encodings.sent.len() + encodings.recv.len());
|
||||
ctx.io_mut().with_limit(frame_limit).send(encodings).await?;
|
||||
|
||||
let root = ctx.io_mut().expect_next().await?;
|
||||
ctx.io_mut().send(secret.clone()).await?;
|
||||
|
||||
Ok(EncodingCommitment {
|
||||
root,
|
||||
secret: secret.clone(),
|
||||
})
|
||||
Ok(EncodingCommitment { root, secret })
|
||||
}
|
||||
|
||||
/// Receives the encodings using the provided MACs.
|
||||
///
|
||||
/// The MACs must be consistent with the global delta used in the encodings.
|
||||
pub(crate) async fn receive<'a>(
|
||||
/// Receives and commits to the encodings for the provided plaintext ranges.
|
||||
pub(crate) async fn receive<M: MacStore>(
|
||||
ctx: &mut Context,
|
||||
hasher: &(dyn HashAlgorithm + Send + Sync),
|
||||
refs: &TranscriptRefs,
|
||||
f: impl Fn(Vector<U8>) -> &'a [Mac],
|
||||
store: &M,
|
||||
hash_alg: HashAlgId,
|
||||
sent: &ReferenceMap,
|
||||
recv: &ReferenceMap,
|
||||
idxs: impl IntoIterator<Item = &(Direction, RangeSet<usize>)>,
|
||||
) -> Result<(EncodingCommitment, EncodingTree), EncodingError> {
|
||||
// Set frame limit and add some extra bytes cushion room.
|
||||
let (sent, recv) = refs.len();
|
||||
let frame_limit = ENCODING_SIZE * (sent + recv) + ctx.io().limit();
|
||||
let hasher: &(dyn HashAlgorithm + Send + Sync) = match hash_alg {
|
||||
HashAlgId::SHA256 => &Sha256::default(),
|
||||
HashAlgId::KECCAK256 => &Keccak256::default(),
|
||||
HashAlgId::BLAKE3 => &Blake3::default(),
|
||||
alg => {
|
||||
return Err(ErrorRepr::UnsupportedHashAlgorithm(alg).into());
|
||||
}
|
||||
};
|
||||
|
||||
let Encodings { mut sent, mut recv } =
|
||||
ctx.io_mut().with_limit(frame_limit).expect_next().await?;
|
||||
let (sent_len, recv_len) = (sent.len(), recv.len());
|
||||
let frame_limit = ctx
|
||||
.io()
|
||||
.limit()
|
||||
.saturating_add(ENCODING_SIZE * (sent_len + recv_len));
|
||||
let encodings: Encodings = ctx.io_mut().with_limit(frame_limit).expect_next().await?;
|
||||
|
||||
let sent_macs: Vec<u8> = refs
|
||||
.sent()
|
||||
.iter()
|
||||
.copied()
|
||||
.flat_map(&f)
|
||||
.flat_map(|mac| mac.as_bytes())
|
||||
.copied()
|
||||
.collect();
|
||||
let recv_macs: Vec<u8> = refs
|
||||
.recv()
|
||||
.iter()
|
||||
.copied()
|
||||
.flat_map(&f)
|
||||
.flat_map(|mac| mac.as_bytes())
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
assert_eq!(sent_macs.len() % ENCODING_SIZE, 0);
|
||||
assert_eq!(recv_macs.len() % ENCODING_SIZE, 0);
|
||||
|
||||
if sent.len() != sent_macs.len() {
|
||||
if encodings.sent.len() != sent_len * ENCODING_SIZE {
|
||||
return Err(ErrorRepr::IncorrectMacCount {
|
||||
direction: Direction::Sent,
|
||||
expected: sent_macs.len(),
|
||||
got: sent.len(),
|
||||
expected: sent_len,
|
||||
got: encodings.sent.len() / ENCODING_SIZE,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
if recv.len() != recv_macs.len() {
|
||||
if encodings.recv.len() != recv_len * ENCODING_SIZE {
|
||||
return Err(ErrorRepr::IncorrectMacCount {
|
||||
direction: Direction::Received,
|
||||
expected: recv_macs.len(),
|
||||
got: recv.len(),
|
||||
expected: recv_len,
|
||||
got: encodings.recv.len() / ENCODING_SIZE,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
sent.iter_mut()
|
||||
.zip(sent_macs)
|
||||
.for_each(|(enc, mac)| *enc ^= mac);
|
||||
recv.iter_mut()
|
||||
.zip(recv_macs)
|
||||
.for_each(|(enc, mac)| *enc ^= mac);
|
||||
// Collects a map of plaintext ranges to their encodings.
|
||||
fn collect_map(
|
||||
store: &impl MacStore,
|
||||
mut encodings: Vec<u8>,
|
||||
map: &ReferenceMap,
|
||||
) -> RangeMap<EncodingSlice> {
|
||||
let mut encoding_map = Vec::new();
|
||||
let mut pos = 0;
|
||||
for (range, chunk) in map.iter() {
|
||||
let macs = store
|
||||
.get_macs(*chunk)
|
||||
.expect("MACs are present for provided plaintext ranges");
|
||||
let encoding = &mut encodings[pos..pos + range.len() * ENCODING_SIZE];
|
||||
encoding
|
||||
.iter_mut()
|
||||
.zip(macs.iter().flat_map(|mac| mac.as_bytes()))
|
||||
.for_each(|(encoding, mac)| {
|
||||
*encoding ^= *mac;
|
||||
});
|
||||
|
||||
let provider = Provider { sent, recv };
|
||||
encoding_map.push((range.start, EncodingSlice::from(&(*encoding))));
|
||||
pos += range.len() * ENCODING_SIZE;
|
||||
}
|
||||
RangeMap::new(encoding_map)
|
||||
}
|
||||
|
||||
let provider = Provider {
|
||||
sent: collect_map(store, encodings.sent, sent),
|
||||
recv: collect_map(store, encodings.recv, recv),
|
||||
};
|
||||
|
||||
let tree = EncodingTree::new(hasher, idxs, &provider)?;
|
||||
let root = tree.root();
|
||||
@@ -185,10 +170,36 @@ pub(crate) async fn receive<'a>(
|
||||
Ok((commitment, tree))
|
||||
}
|
||||
|
||||
pub(crate) trait KeyStore {
|
||||
fn delta(&self) -> Δ
|
||||
|
||||
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]>;
|
||||
}
|
||||
|
||||
impl KeyStore for crate::verifier::Zk {
|
||||
fn delta(&self) -> &Delta {
|
||||
crate::verifier::Zk::delta(self)
|
||||
}
|
||||
|
||||
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]> {
|
||||
self.get_keys(data).ok()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait MacStore {
|
||||
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]>;
|
||||
}
|
||||
|
||||
impl MacStore for crate::prover::Zk {
|
||||
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]> {
|
||||
self.get_macs(data).ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Provider {
|
||||
sent: Vec<u8>,
|
||||
recv: Vec<u8>,
|
||||
sent: RangeMap<EncodingSlice>,
|
||||
recv: RangeMap<EncodingSlice>,
|
||||
}
|
||||
|
||||
impl EncodingProvider for Provider {
|
||||
@@ -203,19 +214,39 @@ impl EncodingProvider for Provider {
|
||||
Direction::Received => &self.recv,
|
||||
};
|
||||
|
||||
let start = range.start * ENCODING_SIZE;
|
||||
let end = range.end * ENCODING_SIZE;
|
||||
let encoding = encodings.get(range).ok_or(EncodingProviderError)?;
|
||||
|
||||
if end > encodings.len() {
|
||||
return Err(EncodingProviderError);
|
||||
}
|
||||
|
||||
dest.extend_from_slice(&encodings[start..end]);
|
||||
dest.extend_from_slice(encoding);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodingSlice(Vec<u8>);
|
||||
|
||||
impl From<&[u8]> for EncodingSlice {
|
||||
fn from(value: &[u8]) -> Self {
|
||||
Self(value.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl Item for EncodingSlice {
|
||||
type Slice<'a>
|
||||
= &'a [u8]
|
||||
where
|
||||
Self: 'a;
|
||||
|
||||
fn length(&self) -> usize {
|
||||
self.0.len() / ENCODING_SIZE
|
||||
}
|
||||
|
||||
fn slice<'a>(&'a self, range: Range<usize>) -> Option<Self::Slice<'a>> {
|
||||
self.0
|
||||
.get(range.start * ENCODING_SIZE..range.end * ENCODING_SIZE)
|
||||
}
|
||||
}
|
||||
|
||||
/// Encoding protocol error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
@@ -234,6 +265,8 @@ enum ErrorRepr {
|
||||
},
|
||||
#[error("encoding tree error: {0}")]
|
||||
EncodingTree(EncodingTreeError),
|
||||
#[error("unsupported hash algorithm: {0}")]
|
||||
UnsupportedHashAlgorithm(HashAlgId),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for EncodingError {
|
||||
|
||||
@@ -9,7 +9,6 @@ pub mod config;
|
||||
pub(crate) mod context;
|
||||
pub(crate) mod encoding;
|
||||
pub(crate) mod ghash;
|
||||
pub(crate) mod msg;
|
||||
pub(crate) mod mux;
|
||||
pub mod prover;
|
||||
pub(crate) mod tag;
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
//! Message types.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use tlsn_core::connection::{HandshakeData, ServerName};
|
||||
|
||||
/// Message sent from Prover to Verifier to prove the server identity.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub(crate) struct ServerIdentityProof {
|
||||
/// Server name.
|
||||
pub name: ServerName,
|
||||
/// Server identity data.
|
||||
pub data: HandshakeData,
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
mod config;
|
||||
mod error;
|
||||
mod future;
|
||||
mod prove;
|
||||
pub mod state;
|
||||
|
||||
pub use config::{ProverConfig, ProverConfigBuilder, TlsConfig, TlsConfigBuilder};
|
||||
@@ -18,19 +19,7 @@ use mpz_vm_core::prelude::*;
|
||||
use mpz_zk::ProverConfig as ZkProverConfig;
|
||||
use webpki::anchor_from_trusted_cert;
|
||||
|
||||
use crate::{
|
||||
Role,
|
||||
commit::{
|
||||
commit_records,
|
||||
hash::prove_hash,
|
||||
transcript::{TranscriptRefs, decode_transcript},
|
||||
},
|
||||
context::build_mt_context,
|
||||
encoding,
|
||||
mux::attach_mux,
|
||||
tag::verify_tags,
|
||||
zk_aes_ctr::ZkAesCtr,
|
||||
};
|
||||
use crate::{Role, context::build_mt_context, mux::attach_mux, tag::verify_tags};
|
||||
|
||||
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
|
||||
use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
|
||||
@@ -39,12 +28,9 @@ use serio::SinkExt;
|
||||
use std::sync::Arc;
|
||||
use tls_client::{ClientConnection, ServerName as TlsServerName};
|
||||
use tls_client_async::{TlsConnection, bind_client};
|
||||
use tls_core::msgs::enums::ContentType;
|
||||
use tlsn_core::{
|
||||
ProvePayload,
|
||||
connection::{HandshakeData, ServerName},
|
||||
hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256},
|
||||
transcript::{TlsTranscript, Transcript, TranscriptCommitment, TranscriptSecret},
|
||||
connection::ServerName,
|
||||
transcript::{TlsTranscript, Transcript},
|
||||
};
|
||||
use tlsn_deap::Deap;
|
||||
use tokio::sync::Mutex;
|
||||
@@ -115,22 +101,6 @@ impl Prover<state::Initialized> {
|
||||
let mut keys = mpc_tls.alloc()?;
|
||||
let vm_lock = vm.try_lock().expect("VM is not locked");
|
||||
translate_keys(&mut keys, &vm_lock)?;
|
||||
|
||||
// Allocate for committing to plaintext.
|
||||
let mut zk_aes_ctr_sent = ZkAesCtr::new(Role::Prover);
|
||||
zk_aes_ctr_sent.set_key(keys.client_write_key, keys.client_write_iv);
|
||||
zk_aes_ctr_sent.alloc(
|
||||
&mut *vm_lock.zk(),
|
||||
self.config.protocol_config().max_sent_data(),
|
||||
)?;
|
||||
|
||||
let mut zk_aes_ctr_recv = ZkAesCtr::new(Role::Prover);
|
||||
zk_aes_ctr_recv.set_key(keys.server_write_key, keys.server_write_iv);
|
||||
zk_aes_ctr_recv.alloc(
|
||||
&mut *vm_lock.zk(),
|
||||
self.config.protocol_config().max_recv_data(),
|
||||
)?;
|
||||
|
||||
drop(vm_lock);
|
||||
|
||||
debug!("setting up mpc-tls");
|
||||
@@ -146,8 +116,6 @@ impl Prover<state::Initialized> {
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
mpc_tls,
|
||||
zk_aes_ctr_sent,
|
||||
zk_aes_ctr_recv,
|
||||
keys,
|
||||
vm,
|
||||
},
|
||||
@@ -173,8 +141,6 @@ impl Prover<state::Setup> {
|
||||
mux_ctrl,
|
||||
mut mux_fut,
|
||||
mpc_tls,
|
||||
mut zk_aes_ctr_sent,
|
||||
mut zk_aes_ctr_recv,
|
||||
keys,
|
||||
vm,
|
||||
..
|
||||
@@ -281,35 +247,6 @@ impl Prover<state::Setup> {
|
||||
)
|
||||
.map_err(ProverError::zk)?;
|
||||
|
||||
// Prove sent and received plaintext. Prover drops the proof
|
||||
// output, as they trust themselves.
|
||||
let (sent_refs, _) = commit_records(
|
||||
&mut vm,
|
||||
&mut zk_aes_ctr_sent,
|
||||
tls_transcript
|
||||
.sent()
|
||||
.iter()
|
||||
.filter(|record| record.typ == ContentType::ApplicationData),
|
||||
)
|
||||
.map_err(ProverError::zk)?;
|
||||
|
||||
|
||||
let (recv_refs, _) = commit_records(
|
||||
&mut vm,
|
||||
&mut zk_aes_ctr_recv,
|
||||
tls_transcript
|
||||
.recv()
|
||||
.iter()
|
||||
.filter(|record| record.typ == ContentType::ApplicationData),
|
||||
)
|
||||
.map_err(|e| {
|
||||
if e.is_insufficient() {
|
||||
ProverError::zk(format!("{e}. Attempted to prove more received data than was configured, increase `max_recv_data` in the config."))
|
||||
} else {
|
||||
ProverError::zk(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
mux_fut
|
||||
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
|
||||
.await?;
|
||||
@@ -317,7 +254,6 @@ impl Prover<state::Setup> {
|
||||
let transcript = tls_transcript
|
||||
.to_transcript()
|
||||
.expect("transcript is complete");
|
||||
let transcript_refs = TranscriptRefs::new(sent_refs, recv_refs);
|
||||
|
||||
Ok(Prover {
|
||||
config: self.config,
|
||||
@@ -327,9 +263,9 @@ impl Prover<state::Setup> {
|
||||
mux_fut,
|
||||
ctx,
|
||||
vm,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
transcript_refs,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -368,117 +304,24 @@ impl Prover<state::Committed> {
|
||||
mux_fut,
|
||||
ctx,
|
||||
vm,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
transcript_refs,
|
||||
..
|
||||
} = &mut self.state;
|
||||
|
||||
let mut output = ProverOutput {
|
||||
transcript_commitments: Vec::new(),
|
||||
transcript_secrets: Vec::new(),
|
||||
};
|
||||
|
||||
let partial_transcript = if let Some((sent, recv)) = config.reveal() {
|
||||
decode_transcript(vm, sent, recv, transcript_refs).map_err(ProverError::zk)?;
|
||||
|
||||
Some(transcript.to_partial(sent.clone(), recv.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let payload = ProvePayload {
|
||||
handshake: config.server_identity().then(|| {
|
||||
(
|
||||
self.config.server_name().clone(),
|
||||
HandshakeData {
|
||||
certs: tls_transcript
|
||||
.server_cert_chain()
|
||||
.expect("server cert chain is present")
|
||||
.to_vec(),
|
||||
sig: tls_transcript
|
||||
.server_signature()
|
||||
.expect("server signature is present")
|
||||
.clone(),
|
||||
binding: tls_transcript.certificate_binding().clone(),
|
||||
},
|
||||
)
|
||||
}),
|
||||
transcript: partial_transcript,
|
||||
transcript_commit: config.transcript_commit().map(|config| config.to_request()),
|
||||
};
|
||||
|
||||
// Send payload.
|
||||
mux_fut
|
||||
.poll_with(ctx.io_mut().send(payload).map_err(ProverError::from))
|
||||
let output = mux_fut
|
||||
.poll_with(prove::prove(
|
||||
ctx,
|
||||
vm,
|
||||
keys,
|
||||
self.config.server_name(),
|
||||
transcript,
|
||||
tls_transcript,
|
||||
config,
|
||||
))
|
||||
.await?;
|
||||
|
||||
let mut hash_commitments = None;
|
||||
if let Some(commit_config) = config.transcript_commit() {
|
||||
if commit_config.has_encoding() {
|
||||
let hasher: &(dyn HashAlgorithm + Send + Sync) =
|
||||
match *commit_config.encoding_hash_alg() {
|
||||
HashAlgId::SHA256 => &Sha256::default(),
|
||||
HashAlgId::KECCAK256 => &Keccak256::default(),
|
||||
HashAlgId::BLAKE3 => &Blake3::default(),
|
||||
alg => {
|
||||
return Err(ProverError::config(format!(
|
||||
"unsupported hash algorithm for encoding commitment: {alg}"
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let (commitment, tree) = mux_fut
|
||||
.poll_with(
|
||||
encoding::receive(
|
||||
ctx,
|
||||
hasher,
|
||||
transcript_refs,
|
||||
|plaintext| vm.get_macs(plaintext).expect("reference is valid"),
|
||||
commit_config.iter_encoding(),
|
||||
)
|
||||
.map_err(ProverError::commit),
|
||||
)
|
||||
.await?;
|
||||
|
||||
output
|
||||
.transcript_commitments
|
||||
.push(TranscriptCommitment::Encoding(commitment));
|
||||
output
|
||||
.transcript_secrets
|
||||
.push(TranscriptSecret::Encoding(tree));
|
||||
}
|
||||
|
||||
if commit_config.has_hash() {
|
||||
hash_commitments = Some(
|
||||
prove_hash(
|
||||
vm,
|
||||
transcript_refs,
|
||||
commit_config
|
||||
.iter_hash()
|
||||
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg)),
|
||||
)
|
||||
.map_err(ProverError::commit)?,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
mux_fut
|
||||
.poll_with(vm.execute_all(ctx).map_err(ProverError::zk))
|
||||
.await?;
|
||||
|
||||
if let Some((hash_fut, hash_secrets)) = hash_commitments {
|
||||
let hash_commitments = hash_fut.try_recv().map_err(ProverError::commit)?;
|
||||
for (commitment, secret) in hash_commitments.into_iter().zip(hash_secrets) {
|
||||
output
|
||||
.transcript_commitments
|
||||
.push(TranscriptCommitment::Hash(commitment));
|
||||
output
|
||||
.transcript_secrets
|
||||
.push(TranscriptSecret::Hash(secret));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::{error::Error, fmt};
|
||||
|
||||
use mpc_tls::MpcTlsError;
|
||||
|
||||
use crate::{encoding::EncodingError, zk_aes_ctr::ZkAesCtrError};
|
||||
use crate::encoding::EncodingError;
|
||||
|
||||
/// Error for [`Prover`](crate::Prover).
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
@@ -110,12 +110,6 @@ impl From<MpcTlsError> for ProverError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ZkAesCtrError> for ProverError {
|
||||
fn from(e: ZkAesCtrError) -> Self {
|
||||
Self::new(ErrorKind::Zk, e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EncodingError> for ProverError {
|
||||
fn from(e: EncodingError) -> Self {
|
||||
Self::new(ErrorKind::Commit, e)
|
||||
|
||||
198
crates/tlsn/src/prover/prove.rs
Normal file
198
crates/tlsn/src/prover/prove.rs
Normal file
@@ -0,0 +1,198 @@
|
||||
use mpc_tls::SessionKeys;
|
||||
use mpz_common::Context;
|
||||
use mpz_memory_core::binary::Binary;
|
||||
use mpz_vm_core::Vm;
|
||||
use rangeset::{RangeSet, UnionMut};
|
||||
use serio::SinkExt;
|
||||
use tlsn_core::{
|
||||
ProveConfig, ProveRequest, ProverOutput,
|
||||
connection::{HandshakeData, ServerName},
|
||||
transcript::{
|
||||
ContentType, Direction, TlsTranscript, Transcript, TranscriptCommitment, TranscriptSecret,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
commit::{auth::prove_plaintext, hash::prove_hash, transcript::TranscriptRefs},
|
||||
encoding::{self, MacStore},
|
||||
prover::ProverError,
|
||||
zk_aes_ctr::ZkAesCtr,
|
||||
};
|
||||
|
||||
pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
|
||||
ctx: &mut Context,
|
||||
vm: &mut T,
|
||||
keys: &SessionKeys,
|
||||
server_name: &ServerName,
|
||||
transcript: &Transcript,
|
||||
tls_transcript: &TlsTranscript,
|
||||
config: &ProveConfig,
|
||||
) -> Result<ProverOutput, ProverError> {
|
||||
let mut output = ProverOutput {
|
||||
transcript_commitments: Vec::default(),
|
||||
transcript_secrets: Vec::default(),
|
||||
};
|
||||
|
||||
let request = ProveRequest {
|
||||
handshake: config.server_identity().then(|| {
|
||||
(
|
||||
server_name.clone(),
|
||||
HandshakeData {
|
||||
certs: tls_transcript
|
||||
.server_cert_chain()
|
||||
.expect("server cert chain is present")
|
||||
.to_vec(),
|
||||
sig: tls_transcript
|
||||
.server_signature()
|
||||
.expect("server signature is present")
|
||||
.clone(),
|
||||
binding: tls_transcript.certificate_binding().clone(),
|
||||
},
|
||||
)
|
||||
}),
|
||||
transcript: config
|
||||
.reveal()
|
||||
.map(|(sent, recv)| transcript.to_partial(sent.clone(), recv.clone())),
|
||||
transcript_commit: config.transcript_commit().map(|config| config.to_request()),
|
||||
};
|
||||
|
||||
ctx.io_mut()
|
||||
.send(request)
|
||||
.await
|
||||
.map_err(ProverError::from)?;
|
||||
|
||||
let mut auth_sent_ranges = RangeSet::default();
|
||||
let mut auth_recv_ranges = RangeSet::default();
|
||||
|
||||
let (reveal_sent, reveal_recv) = config.reveal().cloned().unwrap_or_default();
|
||||
|
||||
auth_sent_ranges.union_mut(&reveal_sent);
|
||||
auth_recv_ranges.union_mut(&reveal_recv);
|
||||
|
||||
if let Some(commit_config) = config.transcript_commit() {
|
||||
commit_config
|
||||
.iter_hash()
|
||||
.for_each(|((direction, idx), _)| match direction {
|
||||
Direction::Sent => auth_sent_ranges.union_mut(idx),
|
||||
Direction::Received => auth_recv_ranges.union_mut(idx),
|
||||
});
|
||||
|
||||
commit_config
|
||||
.iter_encoding()
|
||||
.for_each(|(direction, idx)| match direction {
|
||||
Direction::Sent => auth_sent_ranges.union_mut(idx),
|
||||
Direction::Received => auth_recv_ranges.union_mut(idx),
|
||||
});
|
||||
}
|
||||
|
||||
let mut zk_aes_sent = ZkAesCtr::new(
|
||||
keys.client_write_key,
|
||||
keys.client_write_iv,
|
||||
tls_transcript
|
||||
.sent()
|
||||
.iter()
|
||||
.filter(|record| record.typ == ContentType::ApplicationData),
|
||||
);
|
||||
let mut zk_aes_recv = ZkAesCtr::new(
|
||||
keys.server_write_key,
|
||||
keys.server_write_iv,
|
||||
tls_transcript
|
||||
.recv()
|
||||
.iter()
|
||||
.filter(|record| record.typ == ContentType::ApplicationData),
|
||||
);
|
||||
|
||||
let sent_refs = prove_plaintext(
|
||||
vm,
|
||||
&mut zk_aes_sent,
|
||||
transcript.sent(),
|
||||
&auth_sent_ranges,
|
||||
&reveal_sent,
|
||||
)
|
||||
.map_err(ProverError::commit)?;
|
||||
let recv_refs = prove_plaintext(
|
||||
vm,
|
||||
&mut zk_aes_recv,
|
||||
transcript.received(),
|
||||
&auth_recv_ranges,
|
||||
&reveal_recv,
|
||||
)
|
||||
.map_err(ProverError::commit)?;
|
||||
|
||||
let transcript_refs = TranscriptRefs {
|
||||
sent: sent_refs,
|
||||
recv: recv_refs,
|
||||
};
|
||||
|
||||
let hash_commitments = if let Some(commit_config) = config.transcript_commit()
|
||||
&& commit_config.has_hash()
|
||||
{
|
||||
Some(
|
||||
prove_hash(
|
||||
vm,
|
||||
&transcript_refs,
|
||||
commit_config
|
||||
.iter_hash()
|
||||
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg)),
|
||||
)
|
||||
.map_err(ProverError::commit)?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
vm.execute_all(ctx).await.map_err(ProverError::zk)?;
|
||||
|
||||
if let Some(commit_config) = config.transcript_commit()
|
||||
&& commit_config.has_encoding()
|
||||
{
|
||||
let mut sent_ranges = RangeSet::default();
|
||||
let mut recv_ranges = RangeSet::default();
|
||||
for (dir, idx) in commit_config.iter_encoding() {
|
||||
match dir {
|
||||
Direction::Sent => sent_ranges.union_mut(idx),
|
||||
Direction::Received => recv_ranges.union_mut(idx),
|
||||
}
|
||||
}
|
||||
|
||||
let sent_map = transcript_refs
|
||||
.sent
|
||||
.index(&sent_ranges)
|
||||
.expect("indices are valid");
|
||||
let recv_map = transcript_refs
|
||||
.recv
|
||||
.index(&recv_ranges)
|
||||
.expect("indices are valid");
|
||||
|
||||
let (commitment, tree) = encoding::receive(
|
||||
ctx,
|
||||
vm,
|
||||
*commit_config.encoding_hash_alg(),
|
||||
&sent_map,
|
||||
&recv_map,
|
||||
commit_config.iter_encoding(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
output
|
||||
.transcript_commitments
|
||||
.push(TranscriptCommitment::Encoding(commitment));
|
||||
output
|
||||
.transcript_secrets
|
||||
.push(TranscriptSecret::Encoding(tree));
|
||||
}
|
||||
|
||||
if let Some((hash_fut, hash_secrets)) = hash_commitments {
|
||||
let hash_commitments = hash_fut.try_recv().map_err(ProverError::commit)?;
|
||||
for (commitment, secret) in hash_commitments.into_iter().zip(hash_secrets) {
|
||||
output
|
||||
.transcript_commitments
|
||||
.push(TranscriptCommitment::Hash(commitment));
|
||||
output
|
||||
.transcript_secrets
|
||||
.push(TranscriptSecret::Hash(secret));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
@@ -9,10 +9,8 @@ use tlsn_deap::Deap;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::{
|
||||
commit::transcript::TranscriptRefs,
|
||||
mux::{MuxControl, MuxFuture},
|
||||
prover::{Mpc, Zk},
|
||||
zk_aes_ctr::ZkAesCtr,
|
||||
};
|
||||
|
||||
/// Entry state
|
||||
@@ -25,8 +23,6 @@ pub struct Setup {
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) mpc_tls: MpcTlsLeader,
|
||||
pub(crate) zk_aes_ctr_sent: ZkAesCtr,
|
||||
pub(crate) zk_aes_ctr_recv: ZkAesCtr,
|
||||
pub(crate) keys: SessionKeys,
|
||||
pub(crate) vm: Arc<Mutex<Deap<Mpc, Zk>>>,
|
||||
}
|
||||
@@ -39,9 +35,9 @@ pub struct Committed {
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) ctx: Context,
|
||||
pub(crate) vm: Zk,
|
||||
pub(crate) keys: SessionKeys,
|
||||
pub(crate) tls_transcript: TlsTranscript,
|
||||
pub(crate) transcript: Transcript,
|
||||
pub(crate) transcript_refs: TranscriptRefs,
|
||||
}
|
||||
|
||||
opaque_debug::implement!(Committed);
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
pub(crate) mod config;
|
||||
mod error;
|
||||
pub mod state;
|
||||
mod verify;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -14,18 +15,7 @@ pub use tlsn_core::{
|
||||
};
|
||||
|
||||
use crate::{
|
||||
Role,
|
||||
commit::{
|
||||
commit_records,
|
||||
hash::verify_hash,
|
||||
transcript::{TranscriptRefs, decode_transcript, verify_transcript},
|
||||
},
|
||||
config::ProtocolConfig,
|
||||
context::build_mt_context,
|
||||
encoding,
|
||||
mux::attach_mux,
|
||||
tag::verify_tags,
|
||||
zk_aes_ctr::ZkAesCtr,
|
||||
Role, config::ProtocolConfig, context::build_mt_context, mux::attach_mux, tag::verify_tags,
|
||||
};
|
||||
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
|
||||
use mpc_tls::{MpcTlsFollower, SessionKeys};
|
||||
@@ -35,11 +25,9 @@ use mpz_garble_core::Delta;
|
||||
use mpz_vm_core::prelude::*;
|
||||
use mpz_zk::VerifierConfig as ZkVerifierConfig;
|
||||
use serio::stream::IoStreamExt;
|
||||
use tls_core::msgs::enums::ContentType;
|
||||
use tlsn_core::{
|
||||
ProvePayload,
|
||||
connection::{ConnectionInfo, ServerName},
|
||||
transcript::{TlsTranscript, TranscriptCommitment},
|
||||
transcript::TlsTranscript,
|
||||
};
|
||||
use tlsn_deap::Deap;
|
||||
use tokio::sync::Mutex;
|
||||
@@ -114,23 +102,12 @@ impl Verifier<state::Initialized> {
|
||||
})
|
||||
.await?;
|
||||
|
||||
let delta = Delta::random(&mut rand::rng());
|
||||
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, &protocol_config, delta, ctx);
|
||||
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, &protocol_config, ctx);
|
||||
|
||||
// Allocate resources for MPC-TLS in the VM.
|
||||
let mut keys = mpc_tls.alloc()?;
|
||||
let vm_lock = vm.try_lock().expect("VM is not locked");
|
||||
translate_keys(&mut keys, &vm_lock)?;
|
||||
|
||||
// Allocate for committing to plaintext.
|
||||
let mut zk_aes_ctr_sent = ZkAesCtr::new(Role::Verifier);
|
||||
zk_aes_ctr_sent.set_key(keys.client_write_key, keys.client_write_iv);
|
||||
zk_aes_ctr_sent.alloc(&mut *vm_lock.zk(), protocol_config.max_sent_data())?;
|
||||
|
||||
let mut zk_aes_ctr_recv = ZkAesCtr::new(Role::Verifier);
|
||||
zk_aes_ctr_recv.set_key(keys.server_write_key, keys.server_write_iv);
|
||||
zk_aes_ctr_recv.alloc(&mut *vm_lock.zk(), protocol_config.max_recv_data())?;
|
||||
|
||||
drop(vm_lock);
|
||||
|
||||
debug!("setting up mpc-tls");
|
||||
@@ -145,10 +122,7 @@ impl Verifier<state::Initialized> {
|
||||
state: state::Setup {
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
delta,
|
||||
mpc_tls,
|
||||
zk_aes_ctr_sent,
|
||||
zk_aes_ctr_recv,
|
||||
keys,
|
||||
vm,
|
||||
},
|
||||
@@ -186,10 +160,7 @@ impl Verifier<state::Setup> {
|
||||
let state::Setup {
|
||||
mux_ctrl,
|
||||
mut mux_fut,
|
||||
delta,
|
||||
mpc_tls,
|
||||
mut zk_aes_ctr_sent,
|
||||
mut zk_aes_ctr_recv,
|
||||
vm,
|
||||
keys,
|
||||
} = self.state;
|
||||
@@ -230,27 +201,6 @@ impl Verifier<state::Setup> {
|
||||
)
|
||||
.map_err(VerifierError::zk)?;
|
||||
|
||||
// Prepare for the prover to prove received plaintext.
|
||||
let (sent_refs, sent_proof) = commit_records(
|
||||
&mut vm,
|
||||
&mut zk_aes_ctr_sent,
|
||||
tls_transcript
|
||||
.sent()
|
||||
.iter()
|
||||
.filter(|record| record.typ == ContentType::ApplicationData),
|
||||
)
|
||||
.map_err(VerifierError::zk)?;
|
||||
|
||||
let (recv_refs, recv_proof) = commit_records(
|
||||
&mut vm,
|
||||
&mut zk_aes_ctr_recv,
|
||||
tls_transcript
|
||||
.recv()
|
||||
.iter()
|
||||
.filter(|record| record.typ == ContentType::ApplicationData),
|
||||
)
|
||||
.map_err(VerifierError::zk)?;
|
||||
|
||||
mux_fut
|
||||
.poll_with(vm.execute_all(&mut ctx).map_err(VerifierError::zk))
|
||||
.await?;
|
||||
@@ -260,23 +210,16 @@ impl Verifier<state::Setup> {
|
||||
// authenticated from the verifier's perspective.
|
||||
tag_proof.verify().map_err(VerifierError::zk)?;
|
||||
|
||||
// Verify the plaintext proofs.
|
||||
sent_proof.verify().map_err(VerifierError::zk)?;
|
||||
recv_proof.verify().map_err(VerifierError::zk)?;
|
||||
|
||||
let transcript_refs = TranscriptRefs::new(sent_refs, recv_refs);
|
||||
|
||||
Ok(Verifier {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Committed {
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
delta,
|
||||
ctx,
|
||||
vm,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript_refs,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -301,130 +244,34 @@ impl Verifier<state::Committed> {
|
||||
let state::Committed {
|
||||
mux_fut,
|
||||
ctx,
|
||||
delta,
|
||||
vm,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript_refs,
|
||||
..
|
||||
} = &mut self.state;
|
||||
|
||||
let ProvePayload {
|
||||
handshake,
|
||||
transcript,
|
||||
transcript_commit,
|
||||
} = mux_fut
|
||||
.poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from))
|
||||
.await?;
|
||||
|
||||
let verifier = if let Some(root_store) = self.config.root_store() {
|
||||
let cert_verifier = if let Some(root_store) = self.config.root_store() {
|
||||
ServerCertVerifier::new(root_store).map_err(VerifierError::config)?
|
||||
} else {
|
||||
ServerCertVerifier::mozilla()
|
||||
};
|
||||
|
||||
let server_name = if let Some((name, cert_data)) = handshake {
|
||||
cert_data
|
||||
.verify(
|
||||
&verifier,
|
||||
tls_transcript.time(),
|
||||
tls_transcript.server_ephemeral_key(),
|
||||
&name,
|
||||
)
|
||||
.map_err(VerifierError::verify)?;
|
||||
|
||||
Some(name)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(partial_transcript) = &transcript {
|
||||
let sent_len = tls_transcript
|
||||
.sent()
|
||||
.iter()
|
||||
.filter_map(|record| {
|
||||
if let ContentType::ApplicationData = record.typ {
|
||||
Some(record.ciphertext.len())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.sum::<usize>();
|
||||
|
||||
let recv_len = tls_transcript
|
||||
.recv()
|
||||
.iter()
|
||||
.filter_map(|record| {
|
||||
if let ContentType::ApplicationData = record.typ {
|
||||
Some(record.ciphertext.len())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.sum::<usize>();
|
||||
|
||||
// Check ranges.
|
||||
if partial_transcript.len_sent() != sent_len
|
||||
|| partial_transcript.len_received() != recv_len
|
||||
{
|
||||
return Err(VerifierError::verify(
|
||||
"prover sent transcript with incorrect length",
|
||||
));
|
||||
}
|
||||
|
||||
decode_transcript(
|
||||
vm,
|
||||
partial_transcript.sent_authed(),
|
||||
partial_transcript.received_authed(),
|
||||
transcript_refs,
|
||||
)
|
||||
.map_err(VerifierError::zk)?;
|
||||
}
|
||||
|
||||
let mut transcript_commitments = Vec::new();
|
||||
let mut hash_commitments = None;
|
||||
if let Some(commit_config) = transcript_commit {
|
||||
if commit_config.encoding() {
|
||||
let commitment = mux_fut
|
||||
.poll_with(encoding::transfer(
|
||||
ctx,
|
||||
transcript_refs,
|
||||
delta,
|
||||
|plaintext| vm.get_keys(plaintext).expect("reference is valid"),
|
||||
))
|
||||
.await?;
|
||||
|
||||
transcript_commitments.push(TranscriptCommitment::Encoding(commitment));
|
||||
}
|
||||
|
||||
if commit_config.has_hash() {
|
||||
hash_commitments = Some(
|
||||
verify_hash(vm, transcript_refs, commit_config.iter_hash().cloned())
|
||||
.map_err(VerifierError::verify)?,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
mux_fut
|
||||
.poll_with(vm.execute_all(ctx).map_err(VerifierError::zk))
|
||||
let request = mux_fut
|
||||
.poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from))
|
||||
.await?;
|
||||
|
||||
// Verify revealed data.
|
||||
if let Some(partial_transcript) = &transcript {
|
||||
verify_transcript(vm, partial_transcript, transcript_refs)
|
||||
.map_err(VerifierError::verify)?;
|
||||
}
|
||||
let output = mux_fut
|
||||
.poll_with(verify::verify(
|
||||
ctx,
|
||||
vm,
|
||||
keys,
|
||||
&cert_verifier,
|
||||
tls_transcript,
|
||||
request,
|
||||
))
|
||||
.await?;
|
||||
|
||||
if let Some(hash_commitments) = hash_commitments {
|
||||
for commitment in hash_commitments.try_recv().map_err(VerifierError::verify)? {
|
||||
transcript_commitments.push(TranscriptCommitment::Hash(commitment));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(VerifierOutput {
|
||||
server_name,
|
||||
transcript,
|
||||
transcript_commitments,
|
||||
})
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Closes the connection with the prover.
|
||||
@@ -447,11 +294,11 @@ impl Verifier<state::Committed> {
|
||||
fn build_mpc_tls(
|
||||
config: &VerifierConfig,
|
||||
protocol_config: &ProtocolConfig,
|
||||
delta: Delta,
|
||||
ctx: Context,
|
||||
) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsFollower) {
|
||||
let mut rng = rand::rng();
|
||||
|
||||
let delta = Delta::random(&mut rng);
|
||||
let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
|
||||
let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
|
||||
let rcot_send = mpz_ot::kos::Sender::new(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{encoding::EncodingError, zk_aes_ctr::ZkAesCtrError};
|
||||
use crate::encoding::EncodingError;
|
||||
use mpc_tls::MpcTlsError;
|
||||
use std::{error::Error, fmt};
|
||||
|
||||
@@ -110,12 +110,6 @@ impl From<MpcTlsError> for VerifierError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ZkAesCtrError> for VerifierError {
|
||||
fn from(e: ZkAesCtrError) -> Self {
|
||||
Self::new(ErrorKind::Zk, e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EncodingError> for VerifierError {
|
||||
fn from(e: EncodingError) -> Self {
|
||||
Self::new(ErrorKind::Commit, e)
|
||||
|
||||
@@ -2,14 +2,9 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
commit::transcript::TranscriptRefs,
|
||||
mux::{MuxControl, MuxFuture},
|
||||
zk_aes_ctr::ZkAesCtr,
|
||||
};
|
||||
use crate::mux::{MuxControl, MuxFuture};
|
||||
use mpc_tls::{MpcTlsFollower, SessionKeys};
|
||||
use mpz_common::Context;
|
||||
use mpz_memory_core::correlated::Delta;
|
||||
use tlsn_core::transcript::TlsTranscript;
|
||||
use tlsn_deap::Deap;
|
||||
use tokio::sync::Mutex;
|
||||
@@ -28,10 +23,7 @@ opaque_debug::implement!(Initialized);
|
||||
pub struct Setup {
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) delta: Delta,
|
||||
pub(crate) mpc_tls: MpcTlsFollower,
|
||||
pub(crate) zk_aes_ctr_sent: ZkAesCtr,
|
||||
pub(crate) zk_aes_ctr_recv: ZkAesCtr,
|
||||
pub(crate) keys: SessionKeys,
|
||||
pub(crate) vm: Arc<Mutex<Deap<Mpc, Zk>>>,
|
||||
}
|
||||
@@ -40,11 +32,10 @@ pub struct Setup {
|
||||
pub struct Committed {
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) delta: Delta,
|
||||
pub(crate) ctx: Context,
|
||||
pub(crate) vm: Zk,
|
||||
pub(crate) keys: SessionKeys,
|
||||
pub(crate) tls_transcript: TlsTranscript,
|
||||
pub(crate) transcript_refs: TranscriptRefs,
|
||||
}
|
||||
|
||||
opaque_debug::implement!(Committed);
|
||||
|
||||
183
crates/tlsn/src/verifier/verify.rs
Normal file
183
crates/tlsn/src/verifier/verify.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
use mpc_tls::SessionKeys;
|
||||
use mpz_common::Context;
|
||||
use mpz_memory_core::binary::Binary;
|
||||
use mpz_vm_core::Vm;
|
||||
use rangeset::{RangeSet, UnionMut};
|
||||
use tlsn_core::{
|
||||
ProveRequest, VerifierOutput,
|
||||
transcript::{
|
||||
ContentType, Direction, PartialTranscript, Record, TlsTranscript, TranscriptCommitment,
|
||||
},
|
||||
webpki::ServerCertVerifier,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
commit::{auth::verify_plaintext, hash::verify_hash, transcript::TranscriptRefs},
|
||||
encoding::{self, KeyStore},
|
||||
verifier::VerifierError,
|
||||
zk_aes_ctr::ZkAesCtr,
|
||||
};
|
||||
|
||||
pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
|
||||
ctx: &mut Context,
|
||||
vm: &mut T,
|
||||
keys: &SessionKeys,
|
||||
cert_verifier: &ServerCertVerifier,
|
||||
tls_transcript: &TlsTranscript,
|
||||
request: ProveRequest,
|
||||
) -> Result<VerifierOutput, VerifierError> {
|
||||
let ProveRequest {
|
||||
handshake,
|
||||
transcript,
|
||||
transcript_commit,
|
||||
} = request;
|
||||
|
||||
let ciphertext_sent = collect_ciphertext(tls_transcript.sent());
|
||||
let ciphertext_recv = collect_ciphertext(tls_transcript.recv());
|
||||
|
||||
let has_reveal = transcript.is_some();
|
||||
let transcript = if let Some(transcript) = transcript {
|
||||
if transcript.len_sent() != ciphertext_sent.len()
|
||||
|| transcript.len_received() != ciphertext_recv.len()
|
||||
{
|
||||
return Err(VerifierError::verify(
|
||||
"prover sent transcript with incorrect length",
|
||||
));
|
||||
}
|
||||
|
||||
transcript
|
||||
} else {
|
||||
PartialTranscript::new(ciphertext_sent.len(), ciphertext_recv.len())
|
||||
};
|
||||
|
||||
let server_name = if let Some((name, cert_data)) = handshake {
|
||||
cert_data
|
||||
.verify(
|
||||
cert_verifier,
|
||||
tls_transcript.time(),
|
||||
tls_transcript.server_ephemeral_key(),
|
||||
&name,
|
||||
)
|
||||
.map_err(VerifierError::verify)?;
|
||||
|
||||
Some(name)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut auth_sent_ranges = RangeSet::default();
|
||||
let mut auth_recv_ranges = RangeSet::default();
|
||||
|
||||
auth_sent_ranges.union_mut(transcript.sent_authed());
|
||||
auth_recv_ranges.union_mut(transcript.received_authed());
|
||||
|
||||
if let Some(commit_config) = transcript_commit.as_ref() {
|
||||
commit_config
|
||||
.iter_hash()
|
||||
.for_each(|(direction, idx, _)| match direction {
|
||||
Direction::Sent => auth_sent_ranges.union_mut(idx),
|
||||
Direction::Received => auth_recv_ranges.union_mut(idx),
|
||||
});
|
||||
|
||||
if let Some((sent, recv)) = commit_config.encoding() {
|
||||
auth_sent_ranges.union_mut(sent);
|
||||
auth_recv_ranges.union_mut(recv);
|
||||
}
|
||||
}
|
||||
|
||||
let mut zk_aes_sent = ZkAesCtr::new(
|
||||
keys.client_write_key,
|
||||
keys.client_write_iv,
|
||||
tls_transcript
|
||||
.sent()
|
||||
.iter()
|
||||
.filter(|record| record.typ == ContentType::ApplicationData),
|
||||
);
|
||||
let mut zk_aes_recv = ZkAesCtr::new(
|
||||
keys.server_write_key,
|
||||
keys.server_write_iv,
|
||||
tls_transcript
|
||||
.recv()
|
||||
.iter()
|
||||
.filter(|record| record.typ == ContentType::ApplicationData),
|
||||
);
|
||||
|
||||
let (sent_refs, sent_proof) = verify_plaintext(
|
||||
vm,
|
||||
&mut zk_aes_sent,
|
||||
transcript.sent_unsafe(),
|
||||
&ciphertext_sent,
|
||||
&auth_sent_ranges,
|
||||
transcript.sent_authed(),
|
||||
)
|
||||
.map_err(VerifierError::zk)?;
|
||||
let (recv_refs, recv_proof) = verify_plaintext(
|
||||
vm,
|
||||
&mut zk_aes_recv,
|
||||
transcript.received_unsafe(),
|
||||
&ciphertext_recv,
|
||||
&auth_recv_ranges,
|
||||
transcript.received_authed(),
|
||||
)
|
||||
.map_err(VerifierError::zk)?;
|
||||
|
||||
let transcript_refs = TranscriptRefs {
|
||||
sent: sent_refs,
|
||||
recv: recv_refs,
|
||||
};
|
||||
|
||||
let mut transcript_commitments = Vec::new();
|
||||
let mut hash_commitments = None;
|
||||
if let Some(commit_config) = transcript_commit.as_ref()
|
||||
&& commit_config.has_hash()
|
||||
{
|
||||
hash_commitments = Some(
|
||||
verify_hash(vm, &transcript_refs, commit_config.iter_hash().cloned())
|
||||
.map_err(VerifierError::verify)?,
|
||||
);
|
||||
}
|
||||
|
||||
vm.execute_all(ctx).await.map_err(VerifierError::zk)?;
|
||||
|
||||
sent_proof.verify().map_err(VerifierError::verify)?;
|
||||
recv_proof.verify().map_err(VerifierError::verify)?;
|
||||
|
||||
if let Some(commit_config) = transcript_commit
|
||||
&& let Some((sent, recv)) = commit_config.encoding()
|
||||
{
|
||||
let sent_map = transcript_refs
|
||||
.sent
|
||||
.index(sent)
|
||||
.expect("ranges were authenticated");
|
||||
let recv_map = transcript_refs
|
||||
.recv
|
||||
.index(recv)
|
||||
.expect("ranges were authenticated");
|
||||
|
||||
let commitment = encoding::transfer(ctx, vm, &sent_map, &recv_map).await?;
|
||||
transcript_commitments.push(TranscriptCommitment::Encoding(commitment));
|
||||
}
|
||||
|
||||
if let Some(hash_commitments) = hash_commitments {
|
||||
for commitment in hash_commitments.try_recv().map_err(VerifierError::verify)? {
|
||||
transcript_commitments.push(TranscriptCommitment::Hash(commitment));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(VerifierOutput {
|
||||
server_name,
|
||||
transcript: has_reveal.then_some(transcript),
|
||||
transcript_commitments,
|
||||
})
|
||||
}
|
||||
|
||||
fn collect_ciphertext<'a>(records: impl IntoIterator<Item = &'a Record>) -> Vec<u8> {
|
||||
let mut ciphertext = Vec::new();
|
||||
records
|
||||
.into_iter()
|
||||
.filter(|record| record.typ == ContentType::ApplicationData)
|
||||
.for_each(|record| {
|
||||
ciphertext.extend_from_slice(&record.ciphertext);
|
||||
});
|
||||
ciphertext
|
||||
}
|
||||
@@ -1,181 +1,226 @@
|
||||
//! Zero-knowledge AES-CTR encryption.
|
||||
use std::{ops::Range, sync::Arc};
|
||||
|
||||
use cipher::{
|
||||
Cipher, CipherError, Keystream,
|
||||
aes::{Aes128, AesError},
|
||||
};
|
||||
use mpz_circuits::circuits::{AES128, xor};
|
||||
use mpz_memory_core::{
|
||||
Array, Vector,
|
||||
Array, MemoryExt, Vector, ViewExt,
|
||||
binary::{Binary, U8},
|
||||
};
|
||||
use mpz_vm_core::{Vm, prelude::*};
|
||||
use mpz_vm_core::{Call, CallableExt, Vm};
|
||||
use rangeset::RangeSet;
|
||||
use tlsn_core::transcript::Record;
|
||||
|
||||
use crate::Role;
|
||||
|
||||
type Nonce = Array<U8, 8>;
|
||||
type Ctr = Array<U8, 4>;
|
||||
type Block = Array<U8, 16>;
|
||||
|
||||
const START_CTR: u32 = 2;
|
||||
use crate::commit::transcript::ReferenceMap;
|
||||
|
||||
/// ZK AES-CTR encryption.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ZkAesCtr {
|
||||
role: Role,
|
||||
aes: Aes128,
|
||||
state: State,
|
||||
key: Array<U8, 16>,
|
||||
iv: Array<U8, 4>,
|
||||
records: Vec<(usize, RecordState)>,
|
||||
total_len: usize,
|
||||
}
|
||||
|
||||
impl ZkAesCtr {
|
||||
/// Creates a new ZK AES-CTR instance.
|
||||
pub(crate) fn new(role: Role) -> Self {
|
||||
/// Creates a new instance.
|
||||
pub(crate) fn new<'record>(
|
||||
key: Array<U8, 16>,
|
||||
iv: Array<U8, 4>,
|
||||
records: impl IntoIterator<Item = &'record Record>,
|
||||
) -> Self {
|
||||
let mut pos = 0;
|
||||
let mut record_state = Vec::new();
|
||||
for record in records {
|
||||
record_state.push((
|
||||
pos,
|
||||
RecordState {
|
||||
explicit_nonce: Some(record.explicit_nonce.clone()),
|
||||
explicit_nonce_ref: None,
|
||||
range: pos..pos + record.ciphertext.len(),
|
||||
},
|
||||
));
|
||||
pos += record.ciphertext.len();
|
||||
}
|
||||
|
||||
Self {
|
||||
role,
|
||||
aes: Aes128::default(),
|
||||
state: State::Init,
|
||||
key,
|
||||
iv,
|
||||
records: record_state,
|
||||
total_len: pos,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the role.
|
||||
pub(crate) fn role(&self) -> &Role {
|
||||
&self.role
|
||||
}
|
||||
|
||||
/// Allocates `len` bytes for encryption.
|
||||
pub(crate) fn alloc(
|
||||
/// Allocates the plaintext for the provided ranges.
|
||||
///
|
||||
/// Returns a reference to the plaintext and the ciphertext.
|
||||
pub(crate) fn alloc_plaintext(
|
||||
&mut self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
len: usize,
|
||||
) -> Result<(), ZkAesCtrError> {
|
||||
let State::Init = self.state.take() else {
|
||||
Err(ErrorRepr::State {
|
||||
reason: "must be in init state to allocate",
|
||||
})?
|
||||
};
|
||||
ranges: &RangeSet<usize>,
|
||||
) -> Result<(ReferenceMap, ReferenceMap), ZkAesCtrError> {
|
||||
let len = ranges.len();
|
||||
|
||||
// Round up to the nearest block size.
|
||||
let len = 16 * len.div_ceil(16);
|
||||
|
||||
let input = vm.alloc_vec::<U8>(len).map_err(ZkAesCtrError::vm)?;
|
||||
let keystream = self.aes.alloc_keystream(vm, len)?;
|
||||
|
||||
match self.role {
|
||||
Role::Prover => vm.mark_private(input).map_err(ZkAesCtrError::vm)?,
|
||||
Role::Verifier => vm.mark_blind(input).map_err(ZkAesCtrError::vm)?,
|
||||
if len > self.total_len {
|
||||
return Err(ZkAesCtrError(ErrorRepr::TranscriptBounds {
|
||||
len,
|
||||
max: self.total_len,
|
||||
}));
|
||||
}
|
||||
|
||||
self.state = State::Ready { input, keystream };
|
||||
let plaintext = vm.alloc_vec::<U8>(len).map_err(ZkAesCtrError::vm)?;
|
||||
let keystream = self.alloc_keystream(vm, ranges)?;
|
||||
|
||||
Ok(())
|
||||
let mut builder = Call::builder(Arc::new(xor(len * 8))).arg(plaintext);
|
||||
for slice in keystream {
|
||||
builder = builder.arg(slice);
|
||||
}
|
||||
let call = builder.build().expect("call should be valid");
|
||||
|
||||
let ciphertext: Vector<U8> = vm.call(call).map_err(ZkAesCtrError::vm)?;
|
||||
|
||||
let mut pos = 0;
|
||||
let plaintext = ReferenceMap::from_iter(ranges.iter_ranges().map(move |range| {
|
||||
let chunk = plaintext
|
||||
.get(pos..pos + range.len())
|
||||
.expect("length was checked");
|
||||
pos += range.len();
|
||||
(range.start, chunk)
|
||||
}));
|
||||
|
||||
let mut pos = 0;
|
||||
let ciphertext = ReferenceMap::from_iter(ranges.iter_ranges().map(move |range| {
|
||||
let chunk = ciphertext
|
||||
.get(pos..pos + range.len())
|
||||
.expect("length was checked");
|
||||
pos += range.len();
|
||||
(range.start, chunk)
|
||||
}));
|
||||
|
||||
Ok((plaintext, ciphertext))
|
||||
}
|
||||
|
||||
/// Sets the key and IV for the cipher.
|
||||
pub(crate) fn set_key(&mut self, key: Array<U8, 16>, iv: Array<U8, 4>) {
|
||||
self.aes.set_key(key);
|
||||
self.aes.set_iv(iv);
|
||||
}
|
||||
|
||||
/// Proves the encryption of `len` bytes.
|
||||
///
|
||||
/// Here we only assign certain values in the VM but the actual proving
|
||||
/// happens later when the plaintext is assigned and the VM is executed.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vm` - Virtual machine.
|
||||
/// * `explicit_nonce` - Explicit nonce.
|
||||
/// * `len` - Length of the plaintext in bytes.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A VM reference to the plaintext and the ciphertext.
|
||||
pub(crate) fn encrypt(
|
||||
fn alloc_keystream(
|
||||
&mut self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
explicit_nonce: Vec<u8>,
|
||||
len: usize,
|
||||
) -> Result<(Vector<U8>, Vector<U8>), ZkAesCtrError> {
|
||||
let State::Ready { input, keystream } = &mut self.state else {
|
||||
Err(ErrorRepr::State {
|
||||
reason: "must be in ready state to encrypt",
|
||||
})?
|
||||
};
|
||||
ranges: &RangeSet<usize>,
|
||||
) -> Result<Vec<Vector<U8>>, ZkAesCtrError> {
|
||||
let mut keystream = Vec::new();
|
||||
|
||||
let explicit_nonce: [u8; 8] =
|
||||
explicit_nonce
|
||||
.try_into()
|
||||
.map_err(|explicit_nonce: Vec<_>| ErrorRepr::ExplicitNonceLength {
|
||||
expected: 8,
|
||||
actual: explicit_nonce.len(),
|
||||
})?;
|
||||
let mut range_iter = ranges.iter_ranges();
|
||||
let mut current_range = range_iter.next();
|
||||
for (pos, record) in self.records.iter_mut() {
|
||||
let pos = *pos;
|
||||
let mut current_block = None;
|
||||
loop {
|
||||
let Some(range) = current_range.take().or_else(|| range_iter.next()) else {
|
||||
return Ok(keystream);
|
||||
};
|
||||
|
||||
let block_count = len.div_ceil(16);
|
||||
let padded_len = block_count * 16;
|
||||
let padding_len = padded_len - len;
|
||||
if range.start >= record.range.end {
|
||||
current_range = Some(range);
|
||||
break;
|
||||
}
|
||||
|
||||
if padded_len > input.len() {
|
||||
Err(ErrorRepr::InsufficientPreprocessing {
|
||||
expected: padded_len,
|
||||
actual: input.len(),
|
||||
})?
|
||||
}
|
||||
const BLOCK_SIZE: usize = 16;
|
||||
let block_num = (range.start - pos) / BLOCK_SIZE;
|
||||
let block = if let Some((current_block_num, block)) = current_block.take()
|
||||
&& current_block_num == block_num
|
||||
{
|
||||
block
|
||||
} else {
|
||||
let block = record.alloc_block(vm, self.key, self.iv, block_num)?;
|
||||
|
||||
let mut input = input.split_off(input.len() - padded_len);
|
||||
let keystream = keystream.consume(padded_len)?;
|
||||
let mut output = keystream.apply(vm, input)?;
|
||||
current_block = Some((block_num, block));
|
||||
|
||||
// Assign counter block inputs.
|
||||
let mut ctr = START_CTR..;
|
||||
keystream.assign(vm, explicit_nonce, move || {
|
||||
ctr.next().expect("range is unbounded").to_be_bytes()
|
||||
})?;
|
||||
block
|
||||
};
|
||||
|
||||
// Assign zeroes to the padding.
|
||||
if padding_len > 0 {
|
||||
let padding = input.split_off(input.len() - padding_len);
|
||||
// To simplify the impl, we don't mark the padding as public, that's why only
|
||||
// the prover assigns it.
|
||||
if let Role::Prover = self.role {
|
||||
vm.assign(padding, vec![0; padding_len])
|
||||
.map_err(ZkAesCtrError::vm)?;
|
||||
let start = (range.start - pos) % BLOCK_SIZE;
|
||||
let end = (start + range.len()).min(BLOCK_SIZE);
|
||||
let len = end - start;
|
||||
|
||||
keystream.push(block.get(start..end).expect("range is checked"));
|
||||
|
||||
// If the range is larger than a block, process the tail.
|
||||
if range.len() > BLOCK_SIZE {
|
||||
current_range = Some(range.start + len..range.end);
|
||||
}
|
||||
}
|
||||
vm.commit(padding).map_err(ZkAesCtrError::vm)?;
|
||||
output.truncate(len);
|
||||
}
|
||||
|
||||
Ok((input, output))
|
||||
unreachable!("plaintext length was checked");
|
||||
}
|
||||
}
|
||||
|
||||
enum State {
|
||||
Init,
|
||||
Ready {
|
||||
input: Vector<U8>,
|
||||
keystream: Keystream<Nonce, Ctr, Block>,
|
||||
},
|
||||
Error,
|
||||
#[derive(Debug)]
|
||||
struct RecordState {
|
||||
explicit_nonce: Option<Vec<u8>>,
|
||||
range: Range<usize>,
|
||||
explicit_nonce_ref: Option<Vector<U8>>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn take(&mut self) -> Self {
|
||||
std::mem::replace(self, State::Error)
|
||||
}
|
||||
}
|
||||
impl RecordState {
|
||||
fn alloc_explicit_nonce(
|
||||
&mut self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
) -> Result<Vector<U8>, ZkAesCtrError> {
|
||||
if let Some(explicit_nonce) = self.explicit_nonce_ref {
|
||||
Ok(explicit_nonce)
|
||||
} else {
|
||||
const EXPLICIT_NONCE_LEN: usize = 8;
|
||||
let explicit_nonce_ref = vm
|
||||
.alloc_vec::<U8>(EXPLICIT_NONCE_LEN)
|
||||
.map_err(ZkAesCtrError::vm)?;
|
||||
vm.mark_public(explicit_nonce_ref)
|
||||
.map_err(ZkAesCtrError::vm)?;
|
||||
vm.assign(
|
||||
explicit_nonce_ref,
|
||||
self.explicit_nonce
|
||||
.take()
|
||||
.expect("explicit nonce only set once"),
|
||||
)
|
||||
.map_err(ZkAesCtrError::vm)?;
|
||||
vm.commit(explicit_nonce_ref).map_err(ZkAesCtrError::vm)?;
|
||||
|
||||
impl std::fmt::Debug for State {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
State::Init => write!(f, "Init"),
|
||||
State::Ready { .. } => write!(f, "Ready"),
|
||||
State::Error => write!(f, "Error"),
|
||||
self.explicit_nonce_ref = Some(explicit_nonce_ref);
|
||||
Ok(explicit_nonce_ref)
|
||||
}
|
||||
}
|
||||
|
||||
fn alloc_block(
|
||||
&mut self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
key: Array<U8, 16>,
|
||||
iv: Array<U8, 4>,
|
||||
block: usize,
|
||||
) -> Result<Vector<U8>, ZkAesCtrError> {
|
||||
let explicit_nonce = self.alloc_explicit_nonce(vm)?;
|
||||
let ctr: Array<U8, 4> = vm.alloc().map_err(ZkAesCtrError::vm)?;
|
||||
vm.mark_public(ctr).map_err(ZkAesCtrError::vm)?;
|
||||
const START_CTR: u32 = 2;
|
||||
vm.assign(ctr, (START_CTR + block as u32).to_be_bytes())
|
||||
.map_err(ZkAesCtrError::vm)?;
|
||||
vm.commit(ctr).map_err(ZkAesCtrError::vm)?;
|
||||
|
||||
let block: Array<U8, 16> = vm
|
||||
.call(
|
||||
Call::builder(AES128.clone())
|
||||
.arg(key)
|
||||
.arg(iv)
|
||||
.arg(explicit_nonce)
|
||||
.arg(ctr)
|
||||
.build()
|
||||
.expect("call should be valid"),
|
||||
)
|
||||
.map_err(ZkAesCtrError::vm)?;
|
||||
|
||||
Ok(Vector::from(block))
|
||||
}
|
||||
}
|
||||
|
||||
/// Error for [`ZkAesCtr`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct ZkAesCtrError(#[from] ErrorRepr);
|
||||
pub(crate) struct ZkAesCtrError(#[from] ErrorRepr);
|
||||
|
||||
impl ZkAesCtrError {
|
||||
fn vm<E>(err: E) -> Self
|
||||
@@ -184,35 +229,13 @@ impl ZkAesCtrError {
|
||||
{
|
||||
Self(ErrorRepr::Vm(err.into()))
|
||||
}
|
||||
|
||||
pub fn is_insufficient(&self) -> bool {
|
||||
matches!(self.0, ErrorRepr::InsufficientPreprocessing { .. })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("zk aes error")]
|
||||
enum ErrorRepr {
|
||||
#[error("invalid state: {reason}")]
|
||||
State { reason: &'static str },
|
||||
#[error("cipher error: {0}")]
|
||||
Cipher(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
#[error("vm error: {0}")]
|
||||
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
#[error("invalid explicit nonce length: expected {expected}, got {actual}")]
|
||||
ExplicitNonceLength { expected: usize, actual: usize },
|
||||
#[error("insufficient preprocessing: expected {expected}, got {actual}")]
|
||||
InsufficientPreprocessing { expected: usize, actual: usize },
|
||||
}
|
||||
|
||||
impl From<AesError> for ZkAesCtrError {
|
||||
fn from(err: AesError) -> Self {
|
||||
Self(ErrorRepr::Cipher(Box::new(err)))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CipherError> for ZkAesCtrError {
|
||||
fn from(err: CipherError) -> Self {
|
||||
Self(ErrorRepr::Cipher(Box::new(err)))
|
||||
}
|
||||
#[error("transcript bounds exceeded: {len} > {max}")]
|
||||
TranscriptBounds { len: usize, max: usize },
|
||||
}
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
||||
use rangeset::RangeSet;
|
||||
use tlsn::{
|
||||
config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore},
|
||||
connection::ServerName,
|
||||
hash::{HashAlgId, HashProvider},
|
||||
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
|
||||
transcript::{TranscriptCommitConfig, TranscriptCommitment},
|
||||
transcript::{
|
||||
Direction, TranscriptCommitConfig, TranscriptCommitment, TranscriptCommitmentKind,
|
||||
TranscriptSecret,
|
||||
},
|
||||
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
|
||||
};
|
||||
use tlsn_server_fixture::bind;
|
||||
@@ -86,9 +91,25 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(verifier_soc
|
||||
|
||||
let mut builder = TranscriptCommitConfig::builder(prover.transcript());
|
||||
|
||||
// Commit to everything
|
||||
builder.commit_sent(&(0..sent_tx_len)).unwrap();
|
||||
builder.commit_recv(&(0..recv_tx_len)).unwrap();
|
||||
for kind in [
|
||||
TranscriptCommitmentKind::Encoding,
|
||||
TranscriptCommitmentKind::Hash {
|
||||
alg: HashAlgId::SHA256,
|
||||
},
|
||||
] {
|
||||
builder
|
||||
.commit_with_kind(&(0..sent_tx_len), Direction::Sent, kind)
|
||||
.unwrap();
|
||||
builder
|
||||
.commit_with_kind(&(0..recv_tx_len), Direction::Received, kind)
|
||||
.unwrap();
|
||||
builder
|
||||
.commit_with_kind(&(1..sent_tx_len - 1), Direction::Sent, kind)
|
||||
.unwrap();
|
||||
builder
|
||||
.commit_with_kind(&(1..recv_tx_len - 1), Direction::Received, kind)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let transcript_commit = builder.build().unwrap();
|
||||
|
||||
@@ -102,9 +123,52 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(verifier_soc
|
||||
builder.transcript_commit(transcript_commit);
|
||||
|
||||
let config = builder.build().unwrap();
|
||||
|
||||
prover.prove(&config).await.unwrap();
|
||||
let transcript = prover.transcript().clone();
|
||||
let output = prover.prove(&config).await.unwrap();
|
||||
prover.close().await.unwrap();
|
||||
|
||||
let encoding_tree = output
|
||||
.transcript_secrets
|
||||
.iter()
|
||||
.find_map(|secret| {
|
||||
if let TranscriptSecret::Encoding(tree) = secret {
|
||||
Some(tree)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let encoding_commitment = output
|
||||
.transcript_commitments
|
||||
.iter()
|
||||
.find_map(|commitment| {
|
||||
if let TranscriptCommitment::Encoding(commitment) = commitment {
|
||||
Some(commitment)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let prove_sent = RangeSet::from(1..sent_tx_len - 1);
|
||||
let prove_recv = RangeSet::from(1..recv_tx_len - 1);
|
||||
let idxs = [
|
||||
(Direction::Sent, prove_sent.clone()),
|
||||
(Direction::Received, prove_recv.clone()),
|
||||
];
|
||||
let proof = encoding_tree.proof(idxs.iter()).unwrap();
|
||||
let (auth_sent, auth_recv) = proof
|
||||
.verify_with_provider(
|
||||
&HashProvider::default(),
|
||||
encoding_commitment,
|
||||
transcript.sent(),
|
||||
transcript.received(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(auth_sent, prove_sent);
|
||||
assert_eq!(auth_recv, prove_recv);
|
||||
}
|
||||
|
||||
#[instrument(skip(socket))]
|
||||
@@ -125,14 +189,21 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(soc
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let mut verifier = verifier
|
||||
.setup(socket.compat())
|
||||
.await
|
||||
.unwrap()
|
||||
.run()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let VerifierOutput {
|
||||
server_name,
|
||||
transcript,
|
||||
transcript_commitments,
|
||||
} = verifier
|
||||
.verify(socket.compat(), &VerifyConfig::default())
|
||||
.await
|
||||
.unwrap();
|
||||
} = verifier.verify(&VerifyConfig::default()).await.unwrap();
|
||||
|
||||
verifier.close().await.unwrap();
|
||||
|
||||
let transcript = transcript.unwrap();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user