Compare commits

...

10 Commits

Author SHA1 Message Date
sinu
a56f2dd02b rename ProvePayload to ProveRequest 2025-09-11 12:51:23 -07:00
sinu
38a1ec3f72 import order 2025-09-11 09:43:38 -07:00
sinu
1501bc661f consolidate encoding stuff 2025-09-11 09:42:03 -07:00
sinu
0e2c2cb045 revert additional derives 2025-09-11 09:37:44 -07:00
sinu
a662fb7511 fix import order 2025-09-11 09:32:50 -07:00
th4s
be2e1ab95a adapt tests to new base 2025-09-11 18:15:14 +02:00
th4s
d34d135bfe adapt to new base 2025-09-11 10:55:26 +02:00
th4s
091d26bb63 fmt nightly 2025-09-11 10:13:08 +02:00
th4s
f031fe9a8d allow encoding protocol being executed only once 2025-09-11 10:13:08 +02:00
th4s
12c9a5eb34 refactor(tlsn): improve proving flow prover and verifier
- add `TlsTranscript` fixture for testing
- refactor and simplify the commit flow
- change how the TLS transcript is committed
  - tag verification is still done for the whole transcript
  - plaintext authentication is only done where needed
  - encoding adjustments are only transmitted for needed ranges
  - makes `prove` and `verify` functions more efficient and callable more than once
- decoding now supports server-write-key decoding without authentication
- add more tests
2025-09-11 10:13:06 +02:00
31 changed files with 3744 additions and 989 deletions
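The refactor commit above reworks the disclosure flow so that plaintext is only authenticated where needed and `prove`/`verify` can be called more than once per session. A rough prover-side sketch of how the pieces in this diff fit together, built around the new `reveal_all` builder method and the by-value `prove` call seen in the changed call sites; the `Prover`/`Transcript` setup, the `ProveConfig::builder` constructor and the import paths are not shown in this diff and are assumptions here:

// A rough sketch only: `ProveConfig::builder` and the `Prover`/`Transcript`
// types are assumptions; `reveal_all`, `build`, and the by-value `prove`
// call are the pieces visible in this diff.
async fn disclose_all(
    prover: &mut Prover,
    transcript: &Transcript,
) -> Result<(), Box<dyn std::error::Error>> {
    // Reveal the full transcript in both directions via the new
    // `reveal_all` builder method added in this change set.
    let mut builder = ProveConfig::builder(transcript);
    builder.reveal_all(Direction::Sent)?;
    builder.reveal_all(Direction::Received)?;
    let config = builder.build()?;

    // `prove` now takes its config by value and, per the commit message,
    // can be called more than once within the same session.
    prover.prove(config).await?;

    Ok(())
}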

Cargo.lock generated
View File

@@ -3604,6 +3604,16 @@ version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
[[package]]
name = "lipsum"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "636860251af8963cc40f6b4baadee105f02e21b28131d76eba8e40ce84ab8064"
dependencies = [
"rand 0.8.5",
"rand_chacha 0.3.1",
]
[[package]]
name = "litemap"
version = "0.8.0"
@@ -5136,9 +5146,9 @@ dependencies = [
[[package]]
name = "rustls-webpki"
version = "0.103.4"
version = "0.103.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc"
checksum = "b5a37813727b78798e53c2bec3f5e8fe12a6d6f8389bf9ca7802add4c9905ad8"
dependencies = [
"aws-lc-rs",
"ring 0.17.14",
@@ -5941,11 +5951,15 @@ dependencies = [
name = "tlsn"
version = "0.1.0-alpha.13-pre"
dependencies = [
"aes 0.8.4",
"cipher 0.4.4",
"ctr 0.9.2",
"derive_builder 0.12.0",
"futures",
"ghash 0.5.1",
"http-body-util",
"hyper",
"lipsum",
"mpz-common",
"mpz-core",
"mpz-garble",
@@ -5966,6 +5980,7 @@ dependencies = [
"semver 1.0.26",
"serde",
"serio",
"sha2",
"thiserror 1.0.69",
"tlsn-attestation",
"tlsn-cipher",
@@ -6798,9 +6813,9 @@ checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-id-start"
version = "1.3.1"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f322b60f6b9736017344fa0635d64be2f458fbc04eef65f6be22976dd1ffd5b"
checksum = "81b79ad29b5e19de4260020f8919b443b2ef0277d242ce532ec7b7a2cc8b6007"
[[package]]
name = "unicode-ident"

View File

@@ -100,7 +100,7 @@ bytes = { version = "1.4" }
cfg-if = { version = "1" }
chromiumoxide = { version = "0.7" }
chrono = { version = "0.4" }
cipher = { version = "0.4" }
cipher-crypto = { package = "cipher", version = "0.4" }
clap = { version = "4.5" }
criterion = { version = "0.5" }
ctr = { version = "0.9" }
@@ -123,6 +123,7 @@ inventory = { version = "0.3" }
itybity = { version = "0.2" }
js-sys = { version = "0.3" }
k256 = { version = "0.13" }
lipsum = { version = "0.9" }
log = { version = "0.4" }
once_cell = { version = "1.19" }
opaque-debug = { version = "0.3" }

View File

@@ -31,4 +31,4 @@ mpz-ot = { workspace = true }
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }
ctr = { workspace = true }
cipher = { workspace = true }
cipher-crypto = { workspace = true }

View File

@@ -344,8 +344,8 @@ mod tests {
start_ctr: usize,
msg: Vec<u8>,
) -> Vec<u8> {
use ::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
use aes::Aes128;
use cipher_crypto::{KeyIvInit, StreamCipher, StreamCipherSeek};
use ctr::Ctr32BE;
let mut full_iv = [0u8; 16];
@@ -365,7 +365,7 @@ mod tests {
fn aes128(key: [u8; 16], msg: [u8; 16]) -> [u8; 16] {
use ::aes::Aes128 as TestAes128;
use ::cipher::{BlockEncrypt, KeyInit};
use cipher_crypto::{BlockEncrypt, KeyInit};
let mut msg = msg.into();
let cipher = TestAes128::new(&key.into());

View File

@@ -6,7 +6,10 @@ use rustls_pki_types as webpki_types;
use serde::{Deserialize, Serialize};
use tls_core::msgs::{codec::Codec, enums::NamedGroup, handshake::ServerECDHParams};
use crate::webpki::{CertificateDer, ServerCertVerifier, ServerCertVerifierError};
use crate::{
transcript::TlsTranscript,
webpki::{CertificateDer, ServerCertVerifier, ServerCertVerifierError},
};
/// TLS version.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
@@ -312,6 +315,25 @@ pub struct HandshakeData {
}
impl HandshakeData {
/// Creates a new instance.
///
/// # Arguments
///
/// * `transcript` - The TLS transcript.
pub fn new(transcript: &TlsTranscript) -> Self {
Self {
certs: transcript
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: transcript
.server_signature()
.expect("server signature is present")
.clone(),
binding: transcript.certificate_binding().clone(),
}
}
/// Verifies the handshake data.
///
/// # Arguments

View File

@@ -95,7 +95,7 @@ impl Display for HashAlgId {
}
/// A typed hash value.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TypedHash {
/// The algorithm of the hash.
pub alg: HashAlgId,

View File

@@ -130,6 +130,15 @@ impl<'a> ProveConfigBuilder<'a> {
self.reveal(Direction::Received, ranges)
}
/// Reveals the full transcript range for a given direction.
pub fn reveal_all(
&mut self,
direction: Direction,
) -> Result<&mut Self, ProveConfigBuilderError> {
let len = self.transcript.len_of_direction(direction);
self.reveal(direction, &(0..len))
}
/// Builds the configuration.
pub fn build(self) -> Result<ProveConfig, ProveConfigBuilderError> {
Ok(ProveConfig {
@@ -190,10 +199,10 @@ pub struct VerifyConfigBuilderError(#[from] VerifyConfigBuilderErrorRepr);
#[derive(Debug, thiserror::Error)]
enum VerifyConfigBuilderErrorRepr {}
/// Payload sent to the verifier.
/// Request to prove statements about the connection.
#[doc(hidden)]
#[derive(Debug, Serialize, Deserialize)]
pub struct ProvePayload {
pub struct ProveRequest {
/// Handshake data.
pub handshake: Option<(ServerName, HandshakeData)>,
/// Transcript data.
@@ -202,6 +211,29 @@ pub struct ProvePayload {
pub transcript_commit: Option<TranscriptCommitRequest>,
}
impl ProveRequest {
/// Creates a new prove payload.
///
/// # Arguments
///
/// * `config` - The prove config.
/// * `transcript` - The partial transcript.
/// * `handshake` - The server name and handshake data.
pub fn new(
config: &ProveConfig,
transcript: Option<PartialTranscript>,
handshake: Option<(ServerName, HandshakeData)>,
) -> Self {
let transcript_commit = config.transcript_commit().map(|config| config.to_request());
Self {
handshake,
transcript,
transcript_commit,
}
}
}
/// Prover output.
#[derive(Serialize, Deserialize)]
pub struct ProverOutput {

View File

@@ -1,6 +1,6 @@
//! Transcript commitments.
use std::{collections::HashSet, fmt};
use std::fmt;
use rangeset::ToRangeSet;
use serde::{Deserialize, Serialize};
@@ -69,8 +69,6 @@ pub enum TranscriptSecret {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptCommitConfig {
encoding_hash_alg: HashAlgId,
has_encoding: bool,
has_hash: bool,
commits: Vec<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
}
@@ -85,16 +83,6 @@ impl TranscriptCommitConfig {
&self.encoding_hash_alg
}
/// Returns `true` if the configuration has any encoding commitments.
pub fn has_encoding(&self) -> bool {
self.has_encoding
}
/// Returns `true` if the configuration has any hash commitments.
pub fn has_hash(&self) -> bool {
self.has_hash
}
/// Returns an iterator over the encoding commitment indices.
pub fn iter_encoding(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
self.commits.iter().filter_map(|(idx, kind)| match kind {
@@ -114,7 +102,10 @@ impl TranscriptCommitConfig {
/// Returns a request for the transcript commitments.
pub fn to_request(&self) -> TranscriptCommitRequest {
TranscriptCommitRequest {
encoding: self.has_encoding,
encoding: self
.iter_encoding()
.map(|(dir, idx)| (*dir, idx.clone()))
.collect(),
hash: self
.iter_hash()
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg))
@@ -131,10 +122,8 @@ impl TranscriptCommitConfig {
pub struct TranscriptCommitConfigBuilder<'a> {
transcript: &'a Transcript,
encoding_hash_alg: HashAlgId,
has_encoding: bool,
has_hash: bool,
default_kind: TranscriptCommitmentKind,
commits: HashSet<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
commits: Vec<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
}
impl<'a> TranscriptCommitConfigBuilder<'a> {
@@ -143,10 +132,8 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
Self {
transcript,
encoding_hash_alg: HashAlgId::BLAKE3,
has_encoding: false,
has_hash: false,
default_kind: TranscriptCommitmentKind::Encoding,
commits: HashSet::default(),
commits: Vec::default(),
}
}
@@ -188,14 +175,12 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
),
));
}
let value = ((direction, idx), kind);
match kind {
TranscriptCommitmentKind::Encoding => self.has_encoding = true,
TranscriptCommitmentKind::Hash { .. } => self.has_hash = true,
if !self.commits.contains(&value) {
self.commits.push(value);
}
self.commits.insert(((direction, idx), kind));
Ok(self)
}
@@ -241,8 +226,6 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
pub fn build(self) -> Result<TranscriptCommitConfig, TranscriptCommitConfigBuilderError> {
Ok(TranscriptCommitConfig {
encoding_hash_alg: self.encoding_hash_alg,
has_encoding: self.has_encoding,
has_hash: self.has_hash,
commits: Vec::from_iter(self.commits),
})
}
@@ -289,19 +272,14 @@ impl fmt::Display for TranscriptCommitConfigBuilderError {
/// Request to compute transcript commitments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptCommitRequest {
encoding: bool,
encoding: Vec<(Direction, RangeSet<usize>)>,
hash: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
}
impl TranscriptCommitRequest {
/// Returns `true` if an encoding commitment is requested.
pub fn encoding(&self) -> bool {
self.encoding
}
/// Returns `true` if a hash commitment is requested.
pub fn has_hash(&self) -> bool {
!self.hash.is_empty()
/// Returns an iterator over the encoding commitments.
pub fn iter_encoding(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
self.encoding.iter()
}
/// Returns an iterator over the hash commitments.

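The commit-config changes above drop the `has_encoding`/`has_hash` flags in favor of shipping the concrete committed ranges in `TranscriptCommitRequest`, and swap the builder's `HashSet` of commitments for a `Vec` that skips duplicates itself, which keeps the first-seen order. A minimal, self-contained illustration of that push-if-absent pattern, with plain tuples standing in for the real `(Direction, RangeSet, kind)` entries:

use std::ops::Range;

// Pushes a commitment only if an identical one is not recorded yet, keeping
// the first-seen order (the behavior the builder now implements with a Vec).
fn push_unique<T: PartialEq>(commits: &mut Vec<T>, value: T) {
    if !commits.contains(&value) {
        commits.push(value);
    }
}

fn main() {
    let mut commits: Vec<(&str, Range<usize>)> = Vec::new();
    push_unique(&mut commits, ("sent", 0..10));
    push_unique(&mut commits, ("recv", 5..15));
    push_unique(&mut commits, ("sent", 0..10)); // duplicate, ignored
    assert_eq!(commits.len(), 2);
    assert_eq!(commits[0], ("sent", 0..10));
}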
View File

@@ -12,7 +12,7 @@ const BIT_ENCODING_SIZE: usize = 16;
const BYTE_ENCODING_SIZE: usize = 128;
/// Secret used by an encoder to generate encodings.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct EncoderSecret {
seed: [u8; 32],
delta: [u8; BIT_ENCODING_SIZE],

View File

@@ -255,7 +255,7 @@ async fn notarize(
transcript_commitments,
transcript_secrets,
..
} = prover.prove(&disclosure_config).await?;
} = prover.prove(disclosure_config).await?;
// Build an attestation request.
let mut builder = AttestationRequest::builder(config);

View File

@@ -174,7 +174,7 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let config = builder.build().unwrap();
prover.prove(&config).await.unwrap();
prover.prove(config).await.unwrap();
prover.close().await.unwrap();
}

View File

@@ -93,7 +93,7 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let config = builder.build()?;
prover.prove(&config).await?;
prover.prove(config).await?;
prover.close().await?;
let time_total = time_start.elapsed().as_millis();

View File

@@ -107,7 +107,7 @@ async fn prover(provider: &IoProvider) {
let config = builder.build().unwrap();
prover.prove(&config).await.unwrap();
prover.prove(config).await.unwrap();
prover.close().await.unwrap();
}

View File

@@ -40,6 +40,9 @@ mpz-ot = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-zk = { workspace = true }
aes = { workspace = true }
cipher-crypto = { workspace = true }
ctr = { workspace = true }
derive_builder = { workspace = true }
futures = { workspace = true }
opaque-debug = { workspace = true }
@@ -57,6 +60,8 @@ rangeset = { workspace = true }
webpki-roots = { workspace = true }
[dev-dependencies]
lipsum = { workspace = true }
sha2 = { workspace = true }
rstest = { workspace = true }
tlsn-server-fixture = { workspace = true }
tlsn-server-fixture-certs = { workspace = true }
@@ -65,3 +70,5 @@ tokio-util = { workspace = true, features = ["compat"] }
hyper = { workspace = true, features = ["client"] }
http-body-util = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
tlsn-core = { workspace = true, features = ["fixtures"] }
mpz-ot = { workspace = true, features = ["ideal"] }

File diff suppressed because it is too large

View File

@@ -0,0 +1,708 @@
//! Authentication of the transcript plaintext and creation of the transcript
//! references.
use std::ops::Range;
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{DecodeError, DecodeFutureTyped, MemoryExt, binary::Binary};
use mpz_vm_core::Vm;
use rangeset::{Disjoint, RangeSet, Union, UnionMut};
use tlsn_core::{
hash::HashAlgId,
transcript::{ContentType, Direction, PartialTranscript, Record, TlsTranscript},
};
use crate::{
Role,
commit::transcript::TranscriptRefs,
zk_aes_ctr::{ZkAesCtr, ZkAesCtrError},
};
/// Transcript Authenticator.
pub(crate) struct Authenticator {
encoding: Index,
hash: Index,
decoding: Index,
proving: Index,
}
impl Authenticator {
/// Creates a new authenticator.
///
/// # Arguments
///
/// * `encoding` - Ranges for encoding commitments.
/// * `hash` - Ranges for hash commitments.
/// * `partial` - The partial transcript.
pub(crate) fn new<'a>(
encoding: impl Iterator<Item = &'a (Direction, RangeSet<usize>)>,
hash: impl Iterator<Item = &'a (Direction, RangeSet<usize>, HashAlgId)>,
partial: Option<&PartialTranscript>,
) -> Self {
// Compute encoding index.
let mut encoding_sent = RangeSet::default();
let mut encoding_recv = RangeSet::default();
for (d, idx) in encoding {
match d {
Direction::Sent => encoding_sent.union_mut(idx),
Direction::Received => encoding_recv.union_mut(idx),
}
}
let encoding = Index::new(encoding_sent, encoding_recv);
// Compute hash index.
let mut hash_sent = RangeSet::default();
let mut hash_recv = RangeSet::default();
for (d, idx, _) in hash {
match d {
Direction::Sent => hash_sent.union_mut(idx),
Direction::Received => hash_recv.union_mut(idx),
}
}
let hash = Index {
sent: hash_sent,
recv: hash_recv,
};
// Compute decoding index.
let mut decoding_sent = RangeSet::default();
let mut decoding_recv = RangeSet::default();
if let Some(partial) = partial {
decoding_sent.union_mut(partial.sent_authed());
decoding_recv.union_mut(partial.received_authed());
}
let decoding = Index::new(decoding_sent, decoding_recv);
// Compute proving index.
let mut proving_sent = RangeSet::default();
let mut proving_recv = RangeSet::default();
proving_sent.union_mut(decoding.sent());
proving_sent.union_mut(encoding.sent());
proving_sent.union_mut(hash.sent());
proving_recv.union_mut(decoding.recv());
proving_recv.union_mut(encoding.recv());
proving_recv.union_mut(hash.recv());
let proving = Index::new(proving_sent, proving_recv);
Self {
encoding,
hash,
decoding,
proving,
}
}
/// Authenticates the sent plaintext, returning a proof of encryption and
/// writing the plaintext VM references to the transcript references.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `zk_aes_sent` - ZK AES Cipher for sent traffic.
/// * `transcript` - The TLS transcript.
/// * `transcript_refs` - The transcript references.
pub(crate) fn auth_sent(
&mut self,
vm: &mut dyn Vm<Binary>,
zk_aes_sent: &mut ZkAesCtr,
transcript: &TlsTranscript,
transcript_refs: &mut TranscriptRefs,
) -> Result<RecordProof, AuthError> {
let missing_index = transcript_refs.compute_missing(Direction::Sent, self.proving.sent());
// If there is nothing new to prove, return early.
if missing_index == RangeSet::default() {
return Ok(RecordProof::default());
}
let sent = transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData);
authenticate(
vm,
zk_aes_sent,
Direction::Sent,
sent,
transcript_refs,
missing_index,
)
}
/// Authenticates the received plaintext, returning a proof of encryption
/// and writing the plaintext VM references to the transcript references.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `zk_aes_recv` - ZK AES Cipher for received traffic.
/// * `transcript` - The TLS transcript.
/// * `transcript_refs` - The transcript references.
pub(crate) fn auth_recv(
&mut self,
vm: &mut dyn Vm<Binary>,
zk_aes_recv: &mut ZkAesCtr,
transcript: &TlsTranscript,
transcript_refs: &mut TranscriptRefs,
) -> Result<RecordProof, AuthError> {
let decoding_recv = self.decoding.recv();
let fully_decoded = decoding_recv.union(&transcript_refs.decoded(Direction::Received));
let full_range = 0..transcript_refs.max_len(Direction::Received);
// If the only ranges we need to prove are decoding ranges, and together
// with what was already decoded they cover the full received transcript,
// then we do not need to authenticate here: that is done by
// `crate::commit::decode::verify_transcript`, which uses the server write
// key and iv for verification.
if decoding_recv == self.proving.recv() && fully_decoded == full_range {
return Ok(RecordProof::default());
}
let missing_index =
transcript_refs.compute_missing(Direction::Received, self.proving.recv());
// If there is nothing new to prove, return early.
if missing_index == RangeSet::default() {
return Ok(RecordProof::default());
}
let recv = transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData);
authenticate(
vm,
zk_aes_recv,
Direction::Received,
recv,
transcript_refs,
missing_index,
)
}
/// Returns the sent and received encoding ranges.
pub(crate) fn encoding(&self) -> (&RangeSet<usize>, &RangeSet<usize>) {
(self.encoding.sent(), self.encoding.recv())
}
/// Returns the sent and received hash ranges.
pub(crate) fn hash(&self) -> (&RangeSet<usize>, &RangeSet<usize>) {
(self.hash.sent(), self.hash.recv())
}
/// Returns the sent and received decoding ranges.
pub(crate) fn decoding(&self) -> (&RangeSet<usize>, &RangeSet<usize>) {
(self.decoding.sent(), self.decoding.recv())
}
}
/// Authenticates parts of the transcript in zk.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `zk_aes` - ZK AES Cipher.
/// * `direction` - The direction of the application data.
/// * `app_data` - The application data.
/// * `transcript_refs` - The transcript references.
/// * `missing_index` - The index which needs to be proven.
fn authenticate<'a>(
vm: &mut dyn Vm<Binary>,
zk_aes: &mut ZkAesCtr,
direction: Direction,
app_data: impl Iterator<Item = &'a Record>,
transcript_refs: &mut TranscriptRefs,
missing_index: RangeSet<usize>,
) -> Result<RecordProof, AuthError> {
let mut record_idx = Range::default();
let mut ciphertexts = Vec::new();
for record in app_data {
let record_len = record.ciphertext.len();
record_idx.end += record_len;
if missing_index.is_disjoint(&record_idx) {
record_idx.start += record_len;
continue;
}
let (plaintext_ref, ciphertext_ref) =
zk_aes.encrypt(vm, record.explicit_nonce.clone(), record.ciphertext.len())?;
if let Role::Prover = zk_aes.role() {
let Some(plaintext) = record.plaintext.clone() else {
return Err(AuthError(ErrorRepr::MissingPlainText));
};
vm.assign(plaintext_ref, plaintext).map_err(AuthError::vm)?;
}
vm.commit(plaintext_ref).map_err(AuthError::vm)?;
let ciphertext = vm.decode(ciphertext_ref).map_err(AuthError::vm)?;
transcript_refs.add(direction, &record_idx, plaintext_ref);
ciphertexts.push((ciphertext, record.ciphertext.clone()));
record_idx.start += record_len;
}
let proof = RecordProof { ciphertexts };
Ok(proof)
}
#[derive(Debug, Clone, Default)]
struct Index {
sent: RangeSet<usize>,
recv: RangeSet<usize>,
}
impl Index {
fn new(sent: RangeSet<usize>, recv: RangeSet<usize>) -> Self {
Self { sent, recv }
}
fn sent(&self) -> &RangeSet<usize> {
&self.sent
}
fn recv(&self) -> &RangeSet<usize> {
&self.recv
}
}
/// Proof of encryption.
#[derive(Debug, Default)]
#[must_use]
#[allow(clippy::type_complexity)]
pub(crate) struct RecordProof {
ciphertexts: Vec<(DecodeFutureTyped<BitVec, Vec<u8>>, Vec<u8>)>,
}
impl RecordProof {
/// Verifies the proof.
pub(crate) fn verify(self) -> Result<(), AuthError> {
let Self { ciphertexts } = self;
for (mut ciphertext, expected) in ciphertexts {
let ciphertext = ciphertext
.try_recv()
.map_err(AuthError::vm)?
.ok_or(AuthError(ErrorRepr::MissingDecoding))?;
if ciphertext != expected {
return Err(AuthError(ErrorRepr::InvalidCiphertext));
}
}
Ok(())
}
}
/// Error for [`Authenticator`].
#[derive(Debug, thiserror::Error)]
#[error("transcript authentication error: {0}")]
pub(crate) struct AuthError(#[source] ErrorRepr);
impl AuthError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
}
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("vm error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("zk-aes error: {0}")]
ZkAes(ZkAesCtrError),
#[error("decode error: {0}")]
Decode(DecodeError),
#[error("plaintext is missing in record")]
MissingPlainText,
#[error("decoded value is missing")]
MissingDecoding,
#[error("invalid ciphertext")]
InvalidCiphertext,
}
impl From<ZkAesCtrError> for AuthError {
fn from(value: ZkAesCtrError) -> Self {
Self(ErrorRepr::ZkAes(value))
}
}
impl From<DecodeError> for AuthError {
fn from(value: DecodeError) -> Self {
Self(ErrorRepr::Decode(value))
}
}
#[cfg(test)]
mod tests {
use crate::{
Role,
commit::{
auth::{Authenticator, ErrorRepr},
transcript::TranscriptRefs,
},
zk_aes_ctr::ZkAesCtr,
};
use lipsum::{LIBER_PRIMUS, lipsum};
use mpz_common::context::test_st_context;
use mpz_garble_core::Delta;
use mpz_memory_core::{
Array, MemoryExt, ViewExt,
binary::{Binary, U8},
};
use mpz_ot::ideal::rcot::{IdealRCOTReceiver, IdealRCOTSender, ideal_rcot};
use mpz_vm_core::{Execute, Vm};
use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
use rand::{Rng, SeedableRng, rngs::StdRng};
use rangeset::{RangeSet, UnionMut};
use rstest::{fixture, rstest};
use tlsn_core::{
fixtures::transcript::{IV, KEY, RECORD_SIZE},
hash::HashAlgId,
transcript::{ContentType, Direction, TlsTranscript},
};
#[rstest]
#[tokio::test]
async fn test_authenticator_sent(
encoding: Vec<(Direction, RangeSet<usize>)>,
hashes: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
decoding: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let (sent_decoding, recv_decoding) = decoding;
let partial = transcript
.to_transcript()
.unwrap()
.to_partial(sent_decoding, recv_decoding);
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut refs_prover = transcript_refs.clone();
let mut refs_verifier = transcript_refs;
let (key, iv) = keys(&mut prover, KEY, IV, Role::Prover);
let mut auth_prover = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_prover = ZkAesCtr::new(Role::Prover);
zk_prover.set_key(key, iv);
zk_prover.alloc(&mut prover, SENT_LEN).unwrap();
let (key, iv) = keys(&mut verifier, KEY, IV, Role::Verifier);
let mut auth_verifier = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_verifier = ZkAesCtr::new(Role::Verifier);
zk_verifier.set_key(key, iv);
zk_verifier.alloc(&mut verifier, SENT_LEN).unwrap();
let _ = auth_prover
.auth_sent(&mut prover, &mut zk_prover, &transcript, &mut refs_prover)
.unwrap();
let proof = auth_verifier
.auth_sent(
&mut verifier,
&mut zk_verifier,
&transcript,
&mut refs_verifier,
)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
proof.verify().unwrap();
let mut prove_range: RangeSet<usize> = RangeSet::default();
prove_range.union_mut(&(600..1600));
prove_range.union_mut(&(800..2000));
prove_range.union_mut(&(2600..3700));
let mut expected_ranges = RangeSet::default();
for r in prove_range.iter_ranges() {
let floor = r.start / RECORD_SIZE;
let ceil = r.end.div_ceil(RECORD_SIZE);
let expected = floor * RECORD_SIZE..ceil * RECORD_SIZE;
expected_ranges.union_mut(&expected);
}
assert_eq!(refs_prover.index(Direction::Sent), expected_ranges);
assert_eq!(refs_verifier.index(Direction::Sent), expected_ranges);
}
#[rstest]
#[tokio::test]
async fn test_authenticator_recv(
encoding: Vec<(Direction, RangeSet<usize>)>,
hashes: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
decoding: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let (sent_decoding, recv_decoding) = decoding;
let partial = transcript
.to_transcript()
.unwrap()
.to_partial(sent_decoding, recv_decoding);
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut refs_prover = transcript_refs.clone();
let mut refs_verifier = transcript_refs;
let (key, iv) = keys(&mut prover, KEY, IV, Role::Prover);
let mut auth_prover = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_prover = ZkAesCtr::new(Role::Prover);
zk_prover.set_key(key, iv);
zk_prover.alloc(&mut prover, RECV_LEN).unwrap();
let (key, iv) = keys(&mut verifier, KEY, IV, Role::Verifier);
let mut auth_verifier = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_verifier = ZkAesCtr::new(Role::Verifier);
zk_verifier.set_key(key, iv);
zk_verifier.alloc(&mut verifier, RECV_LEN).unwrap();
let _ = auth_prover
.auth_recv(&mut prover, &mut zk_prover, &transcript, &mut refs_prover)
.unwrap();
let proof = auth_verifier
.auth_recv(
&mut verifier,
&mut zk_verifier,
&transcript,
&mut refs_verifier,
)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
proof.verify().unwrap();
let mut prove_range: RangeSet<usize> = RangeSet::default();
prove_range.union_mut(&(4000..4200));
prove_range.union_mut(&(5000..5800));
prove_range.union_mut(&(6800..RECV_LEN));
let mut expected_ranges = RangeSet::default();
for r in prove_range.iter_ranges() {
let floor = r.start / RECORD_SIZE;
let ceil = r.end.div_ceil(RECORD_SIZE);
let expected = floor * RECORD_SIZE..ceil * RECORD_SIZE;
expected_ranges.union_mut(&expected);
}
assert_eq!(refs_prover.index(Direction::Received), expected_ranges);
assert_eq!(refs_verifier.index(Direction::Received), expected_ranges);
}
#[rstest]
#[tokio::test]
async fn test_authenticator_sent_verify_fail(
encoding: Vec<(Direction, RangeSet<usize>)>,
hashes: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
decoding: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let (sent_decoding, recv_decoding) = decoding;
let partial = transcript
.to_transcript()
.unwrap()
.to_partial(sent_decoding, recv_decoding);
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut refs_prover = transcript_refs.clone();
let mut refs_verifier = transcript_refs;
let (key, iv) = keys(&mut prover, KEY, IV, Role::Prover);
let mut auth_prover = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_prover = ZkAesCtr::new(Role::Prover);
zk_prover.set_key(key, iv);
zk_prover.alloc(&mut prover, SENT_LEN).unwrap();
let (key, iv) = keys(&mut verifier, KEY, IV, Role::Verifier);
let mut auth_verifier = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_verifier = ZkAesCtr::new(Role::Verifier);
zk_verifier.set_key(key, iv);
zk_verifier.alloc(&mut verifier, SENT_LEN).unwrap();
let _ = auth_prover
.auth_sent(&mut prover, &mut zk_prover, &transcript, &mut refs_prover)
.unwrap();
// Forge verifier transcript to check if verify fails.
// Use an index which is part of the proving range.
let forged = forged();
let proof = auth_verifier
.auth_sent(&mut verifier, &mut zk_verifier, &forged, &mut refs_verifier)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
let err = proof.verify().unwrap_err();
assert!(matches!(err.0, ErrorRepr::InvalidCiphertext));
}
fn keys(
vm: &mut dyn Vm<Binary>,
key_value: [u8; 16],
iv_value: [u8; 4],
role: Role,
) -> (Array<U8, 16>, Array<U8, 4>) {
let key: Array<U8, 16> = vm.alloc().unwrap();
let iv: Array<U8, 4> = vm.alloc().unwrap();
if let Role::Prover = role {
vm.mark_private(key).unwrap();
vm.mark_private(iv).unwrap();
vm.assign(key, key_value).unwrap();
vm.assign(iv, iv_value).unwrap();
} else {
vm.mark_blind(key).unwrap();
vm.mark_blind(iv).unwrap();
}
vm.commit(key).unwrap();
vm.commit(iv).unwrap();
(key, iv)
}
#[fixture]
fn decoding() -> (RangeSet<usize>, RangeSet<usize>) {
let sent = 600..1600;
let recv = 4000..4200;
(sent.into(), recv.into())
}
#[fixture]
fn encoding() -> Vec<(Direction, RangeSet<usize>)> {
let sent = 800..2000;
let recv = 5000..5800;
let encoding = vec![
(Direction::Sent, sent.into()),
(Direction::Received, recv.into()),
];
encoding
}
#[fixture]
fn hashes() -> Vec<(Direction, RangeSet<usize>, HashAlgId)> {
let sent = 2600..3700;
let recv = 6800..RECV_LEN;
let alg = HashAlgId::SHA256;
let hashes = vec![
(Direction::Sent, sent.into(), alg),
(Direction::Received, recv.into(), alg),
];
hashes
}
fn vms() -> (Prover<IdealRCOTReceiver>, Verifier<IdealRCOTSender>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_rcot(rng.random(), delta.into_inner());
let prover = Prover::new(ProverConfig::default(), ot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta, ot_send);
(prover, verifier)
}
#[fixture]
fn transcript() -> TlsTranscript {
let sent = LIBER_PRIMUS.as_bytes()[..SENT_LEN].to_vec();
let mut recv = lipsum(RECV_LEN).into_bytes();
recv.truncate(RECV_LEN);
tlsn_core::fixtures::transcript::transcript_fixture(&sent, &recv)
}
#[fixture]
fn forged() -> TlsTranscript {
const WRONG_BYTE_INDEX: usize = 610;
let mut sent = LIBER_PRIMUS.as_bytes()[..SENT_LEN].to_vec();
sent[WRONG_BYTE_INDEX] = sent[WRONG_BYTE_INDEX].wrapping_add(1);
let mut recv = lipsum(RECV_LEN).into_bytes();
recv.truncate(RECV_LEN);
tlsn_core::fixtures::transcript::transcript_fixture(&sent, &recv)
}
#[fixture]
fn transcript_refs(transcript: TlsTranscript) -> TranscriptRefs {
let sent_len = transcript
.sent()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
let recv_len = transcript
.recv()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
TranscriptRefs::new(sent_len, recv_len)
}
const SENT_LEN: usize = 4096;
const RECV_LEN: usize = 8192;
}

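The tests above expect authenticated ranges to be widened to whole TLS records: once any byte of a record is needed, the ZK encryption proof covers the complete record. A small self-contained sketch of that expansion; the 1024-byte record size is a hypothetical value, the fixture's actual `RECORD_SIZE` comes from `tlsn_core::fixtures::transcript`:

// Widens a byte range to whole-record boundaries, mirroring the expected
// ranges computed in the tests above (floor the start and ceil the end to
// a multiple of the record size).
fn expand_to_records(range: std::ops::Range<usize>, record_size: usize) -> std::ops::Range<usize> {
    let start = range.start / record_size * record_size;
    let end = range.end.div_ceil(record_size) * record_size;
    start..end
}

fn main() {
    // With a hypothetical 1024-byte record size, proving bytes 600..1600
    // means authenticating the two full records spanning 0..2048.
    assert_eq!(expand_to_records(600..1600, 1024), 0..2048);
}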
View File

@@ -0,0 +1,615 @@
//! Selective disclosure.
use mpz_memory_core::{
Array, MemoryExt,
binary::{Binary, U8},
};
use mpz_vm_core::Vm;
use rangeset::{Intersection, RangeSet, Subset, Union};
use tlsn_core::transcript::{ContentType, Direction, PartialTranscript, TlsTranscript};
use crate::commit::TranscriptRefs;
/// Decodes parts of the transcript.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `key` - The server write key.
/// * `iv` - The server write iv.
/// * `decoding_ranges` - The decoding ranges.
/// * `transcript_refs` - The transcript references.
pub(crate) fn decode_transcript(
vm: &mut dyn Vm<Binary>,
key: Array<U8, 16>,
iv: Array<U8, 4>,
decoding_ranges: (&RangeSet<usize>, &RangeSet<usize>),
transcript_refs: &mut TranscriptRefs,
) -> Result<(), DecodeError> {
let (sent, recv) = decoding_ranges;
let sent_refs = transcript_refs.get(Direction::Sent, sent);
for slice in sent_refs.into_iter() {
// Drop the future, we don't need it.
drop(vm.decode(slice).map_err(DecodeError::vm));
}
transcript_refs.mark_decoded(Direction::Sent, sent);
// If possible use server write key for decoding.
let fully_decoded = recv.union(&transcript_refs.decoded(Direction::Received));
let full_range = 0..transcript_refs.max_len(Direction::Received);
if fully_decoded == full_range {
// Drop the future, we don't need it.
drop(vm.decode(key).map_err(DecodeError::vm)?);
drop(vm.decode(iv).map_err(DecodeError::vm)?);
transcript_refs.mark_decoded(Direction::Received, &full_range.into());
} else {
let recv_refs = transcript_refs.get(Direction::Received, recv);
for slice in recv_refs {
// Drop the future, we don't need it.
drop(vm.decode(slice).map_err(DecodeError::vm));
}
transcript_refs.mark_decoded(Direction::Received, recv);
}
Ok(())
}
/// Verifies parts of the transcript.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `key` - The server write key.
/// * `iv` - The server write iv.
/// * `decoding_ranges` - The decoding ranges.
/// * `partial` - The partial transcript.
/// * `transcript_refs` - The transcript references.
/// * `transcript` - The TLS transcript.
pub(crate) fn verify_transcript(
vm: &mut dyn Vm<Binary>,
key: Array<U8, 16>,
iv: Array<U8, 4>,
decoding_ranges: (&RangeSet<usize>, &RangeSet<usize>),
partial: Option<&PartialTranscript>,
transcript_refs: &mut TranscriptRefs,
transcript: &TlsTranscript,
) -> Result<(), DecodeError> {
let Some(partial) = partial else {
return Err(DecodeError(ErrorRepr::MissingPartialTranscript));
};
let (sent, recv) = decoding_ranges;
let mut authenticated_data = Vec::new();
// Add sent transcript parts.
let sent_refs = transcript_refs.get(Direction::Sent, sent);
for data in sent_refs.into_iter() {
let plaintext = vm
.get(data)
.map_err(DecodeError::vm)?
.ok_or(DecodeError(ErrorRepr::MissingPlaintext))?;
authenticated_data.extend_from_slice(&plaintext);
}
// Add received transcript parts, if possible using key and iv.
if let (Some(key), Some(iv)) = (
vm.get(key).map_err(DecodeError::vm)?,
vm.get(iv).map_err(DecodeError::vm)?,
) {
let plaintext = verify_with_keys(key, iv, recv, transcript)?;
authenticated_data.extend_from_slice(&plaintext);
} else {
let recv_refs = transcript_refs.get(Direction::Received, recv);
for data in recv_refs {
let plaintext = vm
.get(data)
.map_err(DecodeError::vm)?
.ok_or(DecodeError(ErrorRepr::MissingPlaintext))?;
authenticated_data.extend_from_slice(&plaintext);
}
}
let mut purported_data = Vec::with_capacity(authenticated_data.len());
for range in sent.iter_ranges() {
purported_data.extend_from_slice(&partial.sent_unsafe()[range]);
}
for range in recv.iter_ranges() {
purported_data.extend_from_slice(&partial.received_unsafe()[range]);
}
if purported_data != authenticated_data {
return Err(DecodeError(ErrorRepr::InconsistentTranscript));
}
Ok(())
}
/// Checks the transcript length.
///
/// # Arguments
///
/// * `partial` - The partial transcript.
/// * `transcript` - The TLS transcript.
pub(crate) fn check_transcript_length(
partial: Option<&PartialTranscript>,
transcript: &TlsTranscript,
) -> Result<(), DecodeError> {
let Some(partial) = partial else {
return Err(DecodeError(ErrorRepr::MissingPartialTranscript));
};
let sent_len: usize = transcript
.sent()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
let recv_len: usize = transcript
.recv()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
// Check ranges.
if partial.len_sent() != sent_len || partial.len_received() != recv_len {
return Err(DecodeError(ErrorRepr::VerifyTranscriptLength));
}
Ok(())
}
fn verify_with_keys(
key: [u8; 16],
iv: [u8; 4],
decoding_ranges: &RangeSet<usize>,
transcript: &TlsTranscript,
) -> Result<Vec<u8>, DecodeError> {
let mut plaintexts = Vec::with_capacity(decoding_ranges.len());
let mut position = 0_usize;
let recv_data = transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData);
for record in recv_data {
let current = position..position + record.ciphertext.len();
if !current.is_subset(decoding_ranges) {
position += record.ciphertext.len();
continue;
}
let nonce = record
.explicit_nonce
.clone()
.try_into()
.expect("explicit nonce should be 8 bytes");
let plaintext = aes_apply_keystream(key, iv, nonce, &record.ciphertext);
let record_decoding_range = decoding_ranges.intersection(&current);
for r in record_decoding_range.iter_ranges() {
let shifted = r.start - position..r.end - position;
plaintexts.extend_from_slice(&plaintext[shifted]);
}
position += record.ciphertext.len()
}
Ok(plaintexts)
}
fn aes_apply_keystream(key: [u8; 16], iv: [u8; 4], explicit_nonce: [u8; 8], msg: &[u8]) -> Vec<u8> {
use aes::Aes128;
use cipher_crypto::{KeyIvInit, StreamCipher, StreamCipherSeek};
use ctr::Ctr32BE;
let start_ctr = 2;
let mut full_iv = [0u8; 16];
full_iv[0..4].copy_from_slice(&iv);
full_iv[4..12].copy_from_slice(&explicit_nonce);
let mut cipher = Ctr32BE::<Aes128>::new(&key.into(), &full_iv.into());
let mut out = msg.to_vec();
cipher
.try_seek(start_ctr * 16)
.expect("start counter is less than keystream length");
cipher.apply_keystream(&mut out);
out
}
/// A decoding error.
#[derive(Debug, thiserror::Error)]
#[error("decode error: {0}")]
pub(crate) struct DecodeError(#[source] ErrorRepr);
impl DecodeError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
}
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("vm error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("missing partial transcript")]
MissingPartialTranscript,
#[error("length of partial transcript does not match expected length")]
VerifyTranscriptLength,
#[error("provided transcript does not match exptected")]
InconsistentTranscript,
#[error("trying to get plaintext, but it is missing")]
MissingPlaintext,
}
#[cfg(test)]
mod tests {
use crate::{
Role,
commit::{
TranscriptRefs,
decode::{DecodeError, ErrorRepr, decode_transcript, verify_transcript},
},
};
use lipsum::{LIBER_PRIMUS, lipsum};
use mpz_common::context::test_st_context;
use mpz_garble_core::Delta;
use mpz_memory_core::{
Array, MemoryExt, Vector, ViewExt,
binary::{Binary, U8},
};
use mpz_ot::ideal::rcot::{IdealRCOTReceiver, IdealRCOTSender, ideal_rcot};
use mpz_vm_core::{Execute, Vm};
use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
use rand::{Rng, SeedableRng, rngs::StdRng};
use rangeset::{RangeSet, UnionMut};
use rstest::{fixture, rstest};
use tlsn_core::{
fixtures::transcript::{IV, KEY},
transcript::{ContentType, Direction, PartialTranscript, TlsTranscript},
};
#[rstest]
#[tokio::test]
async fn test_decode(
decoding: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let partial = partial(&transcript, decoding.clone());
decode(decoding, partial, transcript, transcript_refs)
.await
.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_decode_fail(
decoding: (RangeSet<usize>, RangeSet<usize>),
forged: TlsTranscript,
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let partial = partial(&forged, decoding.clone());
let err = decode(decoding, partial, transcript, transcript_refs)
.await
.unwrap_err();
assert!(matches!(err.0, ErrorRepr::InconsistentTranscript));
}
#[rstest]
#[tokio::test]
async fn test_decode_all(
decoding_full: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let partial = partial(&transcript, decoding_full.clone());
decode(decoding_full, partial, transcript, transcript_refs)
.await
.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_decode_all_fail(
decoding_full: (RangeSet<usize>, RangeSet<usize>),
forged: TlsTranscript,
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let partial = partial(&forged, decoding_full.clone());
let err = decode(decoding_full, partial, transcript, transcript_refs)
.await
.unwrap_err();
assert!(matches!(err.0, ErrorRepr::InconsistentTranscript));
}
async fn decode(
decoding: (RangeSet<usize>, RangeSet<usize>),
partial: PartialTranscript,
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) -> Result<(), DecodeError> {
let (sent, recv) = decoding;
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut transcript_refs_verifier = transcript_refs.clone();
let mut transcript_refs_prover = transcript_refs;
let key: [u8; 16] = KEY;
let iv: [u8; 4] = IV;
let (key_prover, iv_prover) = assign_keys(&mut prover, key, iv, Role::Prover);
let (key_verifier, iv_verifier) = assign_keys(&mut verifier, key, iv, Role::Verifier);
assign_transcript(
&mut prover,
Role::Prover,
&transcript,
&mut transcript_refs_prover,
);
assign_transcript(
&mut verifier,
Role::Verifier,
&transcript,
&mut transcript_refs_verifier,
);
decode_transcript(
&mut prover,
key_prover,
iv_prover,
(&sent, &recv),
&mut transcript_refs_prover,
)
.unwrap();
decode_transcript(
&mut verifier,
key_verifier,
iv_verifier,
(&sent, &recv),
&mut transcript_refs_verifier,
)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v),
)
.unwrap();
verify_transcript(
&mut verifier,
key_verifier,
iv_verifier,
(&sent, &recv),
Some(&partial),
&mut transcript_refs_verifier,
&transcript,
)
}
fn assign_keys(
vm: &mut dyn Vm<Binary>,
key_value: [u8; 16],
iv_value: [u8; 4],
role: Role,
) -> (Array<U8, 16>, Array<U8, 4>) {
let key: Array<U8, 16> = vm.alloc().unwrap();
let iv: Array<U8, 4> = vm.alloc().unwrap();
if let Role::Prover = role {
vm.mark_private(key).unwrap();
vm.mark_private(iv).unwrap();
vm.assign(key, key_value).unwrap();
vm.assign(iv, iv_value).unwrap();
} else {
vm.mark_blind(key).unwrap();
vm.mark_blind(iv).unwrap();
}
vm.commit(key).unwrap();
vm.commit(iv).unwrap();
(key, iv)
}
fn assign_transcript(
vm: &mut dyn Vm<Binary>,
role: Role,
transcript: &TlsTranscript,
transcript_refs: &mut TranscriptRefs,
) {
let mut pos = 0_usize;
let sent = transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData);
for record in sent {
let len = record.ciphertext.len();
let cipher_ref: Vector<U8> = vm.alloc_vec(len).unwrap();
vm.mark_public(cipher_ref).unwrap();
vm.assign(cipher_ref, record.ciphertext.clone()).unwrap();
vm.commit(cipher_ref).unwrap();
let plaintext_ref: Vector<U8> = vm.alloc_vec(len).unwrap();
if let Role::Prover = role {
vm.mark_private(plaintext_ref).unwrap();
vm.assign(plaintext_ref, record.plaintext.clone().unwrap())
.unwrap();
} else {
vm.mark_blind(plaintext_ref).unwrap();
}
vm.commit(plaintext_ref).unwrap();
let index = pos..pos + record.ciphertext.len();
transcript_refs.add(Direction::Sent, &index, plaintext_ref);
pos += record.ciphertext.len();
}
pos = 0;
let recv = transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData);
for record in recv {
let len = record.ciphertext.len();
let cipher_ref: Vector<U8> = vm.alloc_vec(len).unwrap();
vm.mark_public(cipher_ref).unwrap();
vm.assign(cipher_ref, record.ciphertext.clone()).unwrap();
vm.commit(cipher_ref).unwrap();
let plaintext_ref: Vector<U8> = vm.alloc_vec(len).unwrap();
if let Role::Prover = role {
vm.mark_private(plaintext_ref).unwrap();
vm.assign(plaintext_ref, record.plaintext.clone().unwrap())
.unwrap();
} else {
vm.mark_blind(plaintext_ref).unwrap();
}
vm.commit(plaintext_ref).unwrap();
let index = pos..pos + record.ciphertext.len();
transcript_refs.add(Direction::Received, &index, plaintext_ref);
pos += record.ciphertext.len();
}
}
fn partial(
transcript: &TlsTranscript,
decoding: (RangeSet<usize>, RangeSet<usize>),
) -> PartialTranscript {
let (sent, recv) = decoding;
transcript.to_transcript().unwrap().to_partial(sent, recv)
}
#[fixture]
fn decoding() -> (RangeSet<usize>, RangeSet<usize>) {
let mut sent = RangeSet::default();
let mut recv = RangeSet::default();
sent.union_mut(&(600..1100));
sent.union_mut(&(3450..4000));
recv.union_mut(&(2000..3000));
recv.union_mut(&(4800..4900));
recv.union_mut(&(6000..7000));
(sent, recv)
}
#[fixture]
fn decoding_full(transcript: TlsTranscript) -> (RangeSet<usize>, RangeSet<usize>) {
let transcript = transcript.to_transcript().unwrap();
let (len_sent, len_recv) = transcript.len();
let sent = (0..len_sent).into();
let recv = (0..len_recv).into();
(sent, recv)
}
#[fixture]
fn transcript() -> TlsTranscript {
let sent = LIBER_PRIMUS.as_bytes()[..SENT_LEN].to_vec();
let mut recv = lipsum(RECV_LEN).into_bytes();
recv.truncate(RECV_LEN);
tlsn_core::fixtures::transcript::transcript_fixture(&sent, &recv)
}
#[fixture]
fn forged() -> TlsTranscript {
const WRONG_BYTE_INDEX: usize = 2200;
let sent = LIBER_PRIMUS.as_bytes()[..SENT_LEN].to_vec();
let mut recv = lipsum(RECV_LEN).into_bytes();
recv.truncate(RECV_LEN);
recv[WRONG_BYTE_INDEX] = recv[WRONG_BYTE_INDEX].wrapping_add(1);
tlsn_core::fixtures::transcript::transcript_fixture(&sent, &recv)
}
#[fixture]
fn transcript_refs(transcript: TlsTranscript) -> TranscriptRefs {
let sent_len = transcript
.sent()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
let recv_len = transcript
.recv()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
TranscriptRefs::new(sent_len, recv_len)
}
fn vms() -> (Prover<IdealRCOTReceiver>, Verifier<IdealRCOTSender>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_rcot(rng.random(), delta.into_inner());
let prover = Prover::new(ProverConfig::default(), ot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta, ot_send);
(prover, verifier)
}
const SENT_LEN: usize = 4096;
const RECV_LEN: usize = 8192;
}

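For context on `aes_apply_keystream` above: in TLS 1.2 AES-GCM (RFC 5288) the per-record nonce is the 4-byte implicit salt from the key block followed by the 8-byte explicit nonce carried in the record, and counter value 1 is consumed by the authentication tag, so the keystream for application data starts at counter 2; that is why the function seeks past the first two 16-byte blocks. A minimal sketch of the counter-block layout, using only the standard library:

// Builds the 16-byte AES-CTR counter block used for TLS 1.2 AES-GCM records:
// 4-byte implicit salt || 8-byte explicit nonce || 32-bit big-endian counter.
// Application data uses counters 2, 3, ... (counter 1 is consumed by the tag).
fn counter_block(salt: [u8; 4], explicit_nonce: [u8; 8], counter: u32) -> [u8; 16] {
    let mut block = [0u8; 16];
    block[..4].copy_from_slice(&salt);
    block[4..12].copy_from_slice(&explicit_nonce);
    block[12..].copy_from_slice(&counter.to_be_bytes());
    block
}

fn main() {
    let block = counter_block([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11, 12], 2);
    assert_eq!(&block[12..], &[0, 0, 0, 2]);
}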
View File

@@ -0,0 +1,530 @@
//! Encoding commitment protocol.
use std::ops::Range;
use mpz_memory_core::{
MemoryType, Vector,
binary::{Binary, U8},
};
use mpz_vm_core::Vm;
use rangeset::{RangeSet, Subset, UnionMut};
use serde::{Deserialize, Serialize};
use tlsn_core::{
hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256, TypedHash},
transcript::{
Direction,
encoding::{
Encoder, EncoderSecret, EncodingProvider, EncodingProviderError, EncodingTree,
EncodingTreeError, new_encoder,
},
},
};
use crate::commit::transcript::TranscriptRefs;
/// Number of encoding bytes per plaintext byte.
pub(crate) const ENCODING_SIZE: usize = 128;
pub(crate) trait EncodingVm<T: MemoryType>: EncodingMemory<T> + Vm<T> {}
impl<T: MemoryType, U> EncodingVm<T> for U where U: EncodingMemory<T> + Vm<T> {}
pub(crate) trait EncodingMemory<T: MemoryType> {
fn get_encodings(&self, values: &[Vector<U8>]) -> Vec<u8>;
}
impl<T> EncodingMemory<Binary> for mpz_zk::Prover<T> {
fn get_encodings(&self, values: &[Vector<U8>]) -> Vec<u8> {
let len = values.iter().map(|v| v.len()).sum::<usize>() * ENCODING_SIZE;
let mut encodings = Vec::with_capacity(len);
for &v in values {
let macs = self.get_macs(v).expect("macs should be available");
encodings.extend(macs.iter().flat_map(|mac| mac.as_bytes()));
}
encodings
}
}
impl<T> EncodingMemory<Binary> for mpz_zk::Verifier<T> {
fn get_encodings(&self, values: &[Vector<U8>]) -> Vec<u8> {
let len = values.iter().map(|v| v.len()).sum::<usize>() * ENCODING_SIZE;
let mut encodings = Vec::with_capacity(len);
for &v in values {
let keys = self.get_keys(v).expect("keys should be available");
encodings.extend(keys.iter().flat_map(|key| key.as_block().as_bytes()));
}
encodings
}
}
/// The encoding adjustments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Encodings {
pub(crate) sent: Vec<u8>,
pub(crate) recv: Vec<u8>,
}
/// Creates encoding commitments.
#[derive(Debug)]
pub(crate) struct EncodingCreator {
hash_id: Option<HashAlgId>,
sent: RangeSet<usize>,
recv: RangeSet<usize>,
idxs: Vec<(Direction, RangeSet<usize>)>,
}
impl EncodingCreator {
/// Creates a new encoding creator.
///
/// # Arguments
///
/// * `hash_id` - The id of the hash algorithm.
/// * `idxs` - The indices for encoding commitments.
pub(crate) fn new(hash_id: Option<HashAlgId>, idxs: Vec<(Direction, RangeSet<usize>)>) -> Self {
let mut sent = RangeSet::default();
let mut recv = RangeSet::default();
for (direction, idx) in idxs.iter() {
for range in idx.iter_ranges() {
match direction {
Direction::Sent => sent.union_mut(&range),
Direction::Received => recv.union_mut(&range),
}
}
}
Self {
hash_id,
sent,
recv,
idxs,
}
}
/// Receives the encodings using the provided MACs from the encoding memory.
///
/// The MACs must be consistent with the global delta used in the encodings.
///
/// # Arguments
///
/// * `encoding_mem` - The encoding memory.
/// * `encodings` - The encoding adjustments.
/// * `transcript_refs` - The transcript references.
pub(crate) fn receive(
&self,
encoding_mem: &dyn EncodingMemory<Binary>,
encodings: Encodings,
transcript_refs: &TranscriptRefs,
) -> Result<(TypedHash, EncodingTree), EncodingError> {
let Some(id) = self.hash_id else {
return Err(EncodingError(ErrorRepr::MissingHashId));
};
let hasher: &(dyn HashAlgorithm + Send + Sync) = match id {
HashAlgId::SHA256 => &Sha256::default(),
HashAlgId::KECCAK256 => &Keccak256::default(),
HashAlgId::BLAKE3 => &Blake3::default(),
alg => {
return Err(EncodingError(ErrorRepr::UnsupportedHashAlg(alg)));
}
};
let Encodings {
sent: mut sent_adjust,
recv: mut recv_adjust,
} = encodings;
let sent_refs = transcript_refs.get(Direction::Sent, &self.sent);
let sent = encoding_mem.get_encodings(&sent_refs);
let recv_refs = transcript_refs.get(Direction::Received, &self.recv);
let recv = encoding_mem.get_encodings(&recv_refs);
adjust(&sent, &recv, &mut sent_adjust, &mut recv_adjust)?;
let provider = Provider::new(sent_adjust, &self.sent, recv_adjust, &self.recv);
let tree = EncodingTree::new(hasher, self.idxs.iter(), &provider)?;
let root = tree.root();
Ok((root, tree))
}
/// Transfers the encodings using the provided secret and keys from the
/// encoding memory.
///
/// The keys must be consistent with the global delta used in the encodings.
///
/// # Arguments
///
/// * `encoding_mem` - The encoding memory.
/// * `secret` - The encoder secret.
/// * `transcript_refs` - The transcript references.
pub(crate) fn transfer(
&self,
encoding_mem: &dyn EncodingMemory<Binary>,
secret: EncoderSecret,
transcript_refs: &TranscriptRefs,
) -> Result<Encodings, EncodingError> {
let encoder = new_encoder(&secret);
let mut sent_zero = Vec::with_capacity(self.sent.len() * ENCODING_SIZE);
let mut recv_zero = Vec::with_capacity(self.recv.len() * ENCODING_SIZE);
for range in self.sent.iter_ranges() {
encoder.encode_range(Direction::Sent, range, &mut sent_zero);
}
for range in self.recv.iter_ranges() {
encoder.encode_range(Direction::Received, range, &mut recv_zero);
}
let sent_refs = transcript_refs.get(Direction::Sent, &self.sent);
let sent = encoding_mem.get_encodings(&sent_refs);
let recv_refs = transcript_refs.get(Direction::Received, &self.recv);
let recv = encoding_mem.get_encodings(&recv_refs);
adjust(&sent, &recv, &mut sent_zero, &mut recv_zero)?;
let encodings = Encodings {
sent: sent_zero,
recv: recv_zero,
};
Ok(encodings)
}
}
/// XORs the adjustment bytes with the corresponding encodings in place.
///
/// # Arguments
///
/// * `sent` - The encodings for the sent bytes.
/// * `recv` - The encodings for the received bytes.
/// * `sent_adjust` - The adjustment bytes for the encodings of the sent bytes.
/// * `recv_adjust` - The adjustment bytes for the encodings of the received
/// bytes.
fn adjust(
sent: &[u8],
recv: &[u8],
sent_adjust: &mut [u8],
recv_adjust: &mut [u8],
) -> Result<(), EncodingError> {
assert_eq!(sent.len() % ENCODING_SIZE, 0);
assert_eq!(recv.len() % ENCODING_SIZE, 0);
if sent_adjust.len() != sent.len() {
return Err(ErrorRepr::IncorrectAdjustCount {
direction: Direction::Sent,
expected: sent.len(),
got: sent_adjust.len(),
}
.into());
}
if recv_adjust.len() != recv.len() {
return Err(ErrorRepr::IncorrectAdjustCount {
direction: Direction::Received,
expected: recv.len(),
got: recv_adjust.len(),
}
.into());
}
sent_adjust
.iter_mut()
.zip(sent)
.for_each(|(adjust, enc)| *adjust ^= enc);
recv_adjust
.iter_mut()
.zip(recv)
.for_each(|(adjust, enc)| *adjust ^= enc);
Ok(())
}
#[derive(Debug)]
struct Provider {
sent: Vec<u8>,
sent_range: RangeSet<usize>,
recv: Vec<u8>,
recv_range: RangeSet<usize>,
}
impl Provider {
fn new(
sent: Vec<u8>,
sent_range: &RangeSet<usize>,
recv: Vec<u8>,
recv_range: &RangeSet<usize>,
) -> Self {
assert_eq!(
sent.len(),
sent_range.len() * ENCODING_SIZE,
"length of sent encodings and their index length do not match"
);
assert_eq!(
recv.len(),
recv_range.len() * ENCODING_SIZE,
"length of received encodings and their index length do not match"
);
Self {
sent,
sent_range: sent_range.clone(),
recv,
recv_range: recv_range.clone(),
}
}
fn adjust(
&self,
direction: Direction,
range: &Range<usize>,
) -> Result<Range<usize>, EncodingProviderError> {
let internal_range = match direction {
Direction::Sent => &self.sent_range,
Direction::Received => &self.recv_range,
};
if !range.is_subset(internal_range) {
return Err(EncodingProviderError);
}
let shift = internal_range
.iter()
.take_while(|&el| el < range.start)
.count();
let translated = Range {
start: shift,
end: shift + range.len(),
};
Ok(translated)
}
}
impl EncodingProvider for Provider {
fn provide_encoding(
&self,
direction: Direction,
range: Range<usize>,
dest: &mut Vec<u8>,
) -> Result<(), EncodingProviderError> {
let encodings = match direction {
Direction::Sent => &self.sent,
Direction::Received => &self.recv,
};
let range = self.adjust(direction, &range)?;
let start = range.start * ENCODING_SIZE;
let end = range.end * ENCODING_SIZE;
dest.extend_from_slice(&encodings[start..end]);
Ok(())
}
}
/// Encoding protocol error.
#[derive(Debug, thiserror::Error)]
#[error("encoding protocol error: {0}")]
pub(crate) struct EncodingError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("incorrect adjustment count for {direction}: expected {expected}, got {got}")]
IncorrectAdjustCount {
direction: Direction,
expected: usize,
got: usize,
},
#[error("encoding tree error: {0}")]
EncodingTree(EncodingTreeError),
#[error("missing hash id")]
MissingHashId,
#[error("unsupported hash algorithm for encoding commitment: {0}")]
UnsupportedHashAlg(HashAlgId),
}
impl From<EncodingTreeError> for EncodingError {
fn from(value: EncodingTreeError) -> Self {
Self(ErrorRepr::EncodingTree(value))
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::ops::Range;
use crate::commit::{
encoding::{ENCODING_SIZE, EncodingCreator, Encodings, Provider},
transcript::TranscriptRefs,
};
use mpz_core::Block;
use mpz_garble_core::Delta;
use mpz_memory_core::{
FromRaw, Slice, ToRaw, Vector,
binary::{Binary, U8},
};
use rangeset::{RangeSet, UnionMut};
use rstest::{fixture, rstest};
use tlsn_core::{
hash::{HashAlgId, HashProvider},
transcript::{
Direction,
encoding::{EncoderSecret, EncodingCommitment, EncodingProvider},
},
};
#[rstest]
fn test_encoding_adjust(
index: (RangeSet<usize>, RangeSet<usize>),
transcript_refs: TranscriptRefs,
encoding_idxs: Vec<(Direction, RangeSet<usize>)>,
) {
let creator = EncodingCreator::new(Some(HashAlgId::SHA256), encoding_idxs);
let mock_memory = MockEncodingMemory;
let delta = Delta::new(Block::ONES);
let seed = [1_u8; 32];
let secret = EncoderSecret::new(seed, delta.as_block().to_bytes());
let adjustments = creator
.transfer(&mock_memory, secret, &transcript_refs)
.unwrap();
let (root, tree) = creator
.receive(&mock_memory, adjustments, &transcript_refs)
.unwrap();
// Check correctness of encoding protocol.
let mut idxs = Vec::new();
let (sent_range, recv_range) = index;
idxs.push((Direction::Sent, sent_range.clone()));
idxs.push((Direction::Received, recv_range.clone()));
let commitment = EncodingCommitment { root, secret };
let proof = tree.proof(idxs.iter()).unwrap();
// Here we set the transcript plaintext to just be zeroes, which is possible
// because it is not determined by the encodings so far.
let sent = vec![0_u8; transcript_refs.max_len(Direction::Sent)];
let recv = vec![0_u8; transcript_refs.max_len(Direction::Received)];
let (idx_sent, idx_recv) = proof
.verify_with_provider(&HashProvider::default(), &commitment, &sent, &recv)
.unwrap();
assert_eq!(idx_sent, idxs[0].1);
assert_eq!(idx_recv, idxs[1].1);
}
#[rstest]
fn test_encoding_provider(index: (RangeSet<usize>, RangeSet<usize>), encodings: Encodings) {
let (sent_range, recv_range) = index;
let Encodings { sent, recv } = encodings;
let provider = Provider::new(sent, &sent_range, recv, &recv_range);
let mut encodings_sent = Vec::new();
let mut encodings_recv = Vec::new();
provider
.provide_encoding(Direction::Sent, 16..24, &mut encodings_sent)
.unwrap();
provider
.provide_encoding(Direction::Received, 56..64, &mut encodings_recv)
.unwrap();
let expected_sent = generate_encodings((16..24).into());
let expected_recv = generate_encodings((56..64).into());
assert_eq!(expected_sent, encodings_sent);
assert_eq!(expected_recv, encodings_recv);
}
#[fixture]
fn transcript_refs(index: (RangeSet<usize>, RangeSet<usize>)) -> TranscriptRefs {
let mut transcript_refs = TranscriptRefs::new(40, 64);
let dummy = |range: Range<usize>| {
Vector::<U8>::from_raw(Slice::from_range_unchecked(8 * range.start..8 * range.end))
};
for range in index.0.iter_ranges() {
transcript_refs.add(Direction::Sent, &range, dummy(range.clone()));
}
for range in index.1.iter_ranges() {
transcript_refs.add(Direction::Received, &range, dummy(range.clone()));
}
transcript_refs
}
#[fixture]
fn encodings(index: (RangeSet<usize>, RangeSet<usize>)) -> Encodings {
let sent = generate_encodings(index.0);
let recv = generate_encodings(index.1);
Encodings { sent, recv }
}
#[fixture]
fn encoding_idxs(
index: (RangeSet<usize>, RangeSet<usize>),
) -> Vec<(Direction, RangeSet<usize>)> {
let (sent, recv) = index;
vec![(Direction::Sent, sent), (Direction::Received, recv)]
}
#[fixture]
fn index() -> (RangeSet<usize>, RangeSet<usize>) {
let mut sent = RangeSet::default();
sent.union_mut(&(0..8));
sent.union_mut(&(16..24));
sent.union_mut(&(32..40));
let mut recv = RangeSet::default();
recv.union_mut(&(40..48));
recv.union_mut(&(56..64));
(sent, recv)
}
#[derive(Clone, Copy)]
struct MockEncodingMemory;
impl EncodingMemory<Binary> for MockEncodingMemory {
fn get_encodings(&self, values: &[Vector<U8>]) -> Vec<u8> {
let ranges: Vec<Range<usize>> = values
.iter()
.map(|r| {
let range = r.to_raw().to_range();
range.start / 8..range.end / 8
})
.collect();
let ranges: RangeSet<usize> = ranges.into();
generate_encodings(ranges)
}
}
fn generate_encodings(index: RangeSet<usize>) -> Vec<u8> {
let mut out = Vec::new();
for el in index.iter() {
out.extend_from_slice(&[el as u8; ENCODING_SIZE]);
}
out
}
}

View File

@@ -20,34 +20,171 @@ use tlsn_core::{
use crate::{Role, commit::transcript::TranscriptRefs};
/// Future which will resolve to the committed hash values.
/// Creates plaintext hashes.
#[derive(Debug)]
pub(crate) struct HashCommitFuture {
#[allow(clippy::type_complexity)]
futs: Vec<(
Direction,
RangeSet<usize>,
HashAlgId,
DecodeFutureTyped<BitVec, Vec<u8>>,
)>,
pub(crate) struct PlaintextHasher {
ranges: Vec<HashRange>,
}
impl HashCommitFuture {
impl PlaintextHasher {
/// Creates a new instance.
///
/// # Arguments
///
/// * `indices` - The hash indices.
pub(crate) fn new<'a>(
indices: impl Iterator<Item = &'a (Direction, RangeSet<usize>, HashAlgId)>,
) -> Self {
let mut ranges = Vec::new();
for (direction, index, id) in indices {
let hash_range = HashRange::new(*direction, index.clone(), *id);
ranges.push(hash_range);
}
Self { ranges }
}
/// Prove plaintext hash commitments.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `transcript_refs` - The transcript references.
pub(crate) fn prove(
&self,
vm: &mut dyn Vm<Binary>,
transcript_refs: &TranscriptRefs,
) -> Result<(HashFuture, Vec<PlaintextHashSecret>), HashCommitError> {
let (hash_refs, blinders) = commit(vm, &self.ranges, Role::Prover, transcript_refs)?;
let mut futures = Vec::new();
let mut secrets = Vec::new();
for ((range, hash_ref), blinder_ref) in self.ranges.iter().zip(hash_refs).zip(blinders) {
let blinder: Blinder = rand::random();
vm.assign(blinder_ref, blinder.as_bytes().to_vec())?;
vm.commit(blinder_ref)?;
let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;
futures.push((range.clone(), hash_fut));
secrets.push(PlaintextHashSecret {
direction: range.direction,
idx: range.range.clone(),
blinder,
alg: range.id,
});
}
let hashes = HashFuture { futures };
Ok((hashes, secrets))
}
/// Verify plaintext hash commitments.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `transcript_refs` - The transcript references.
pub(crate) fn verify(
&self,
vm: &mut dyn Vm<Binary>,
transcript_refs: &TranscriptRefs,
) -> Result<HashFuture, HashCommitError> {
let (hash_refs, blinders) = commit(vm, &self.ranges, Role::Verifier, transcript_refs)?;
let mut futures = Vec::new();
for ((range, hash_ref), blinder) in self.ranges.iter().zip(hash_refs).zip(blinders) {
vm.commit(blinder)?;
let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;
futures.push((range.clone(), hash_fut))
}
let hashes = HashFuture { futures };
Ok(hashes)
}
}
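// Illustrative end-to-end flow (sketch; mirrors the test below, not part of
// this diff). The prover obtains the hash futures together with the blinders,
// the verifier obtains only the futures:
//
//   let hasher = PlaintextHasher::new(ranges.iter());
//   let (hash_fut, secrets) = hasher.prove(&mut vm, &transcript_refs)?;
//   vm.execute_all(&mut ctx).await?;   // resolves the decode futures
//   let hashes = hash_fut.try_recv()?; // Vec<PlaintextHash>
//
// The verifier calls `hasher.verify(&mut vm, &transcript_refs)?` instead and
// ends up with the same hashes without learning the plaintext or the blinders.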
/// Commit plaintext hashes of the transcript.
#[allow(clippy::type_complexity)]
fn commit(
vm: &mut dyn Vm<Binary>,
ranges: &[HashRange],
role: Role,
refs: &TranscriptRefs,
) -> Result<(Vec<Array<U8, 32>>, Vec<Vector<U8>>), HashCommitError> {
let mut hashers = HashMap::new();
let mut hash_refs = Vec::new();
let mut blinders = Vec::new();
for HashRange {
direction,
range,
id,
} in ranges.iter()
{
let blinder = vm.alloc_vec::<U8>(16)?;
match role {
Role::Prover => vm.mark_private(blinder)?,
Role::Verifier => vm.mark_blind(blinder)?,
}
let hash = match *id {
HashAlgId::SHA256 => {
let mut hasher = if let Some(hasher) = hashers.get(id).cloned() {
hasher
} else {
let hasher = Sha256::new_with_init(vm).map_err(HashCommitError::hasher)?;
hashers.insert(id, hasher.clone());
hasher
};
for plaintext in refs.get(*direction, range) {
hasher.update(&plaintext);
}
hasher.update(&blinder);
hasher.finalize(vm).map_err(HashCommitError::hasher)?
}
id => {
return Err(HashCommitError::unsupported_alg(id));
}
};
hash_refs.push(hash);
blinders.push(blinder);
}
Ok((hash_refs, blinders))
}
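// Outside the VM the commitment opens as follows (sketch, using the `sha2`
// crate as in the tests below): concatenate the plaintext bytes of the
// committed range, append the 16-byte blinder, and hash:
//
//   let mut hasher = sha2::Sha256::new(); // requires `use sha2::Digest`
//   hasher.update(&plaintext);            // committed range bytes, concatenated
//   hasher.update(secret.blinder.as_bytes());
//   let expected = hasher.finalize();
//
// which must equal the decoded hash value for that range.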
/// Future which will resolve to the committed hash values.
#[derive(Debug)]
pub(crate) struct HashFuture {
futures: Vec<(HashRange, DecodeFutureTyped<BitVec, Vec<u8>>)>,
}
impl HashFuture {
/// Tries to receive the value, returning an error if the value is not
/// ready.
pub(crate) fn try_recv(self) -> Result<Vec<PlaintextHash>, HashCommitError> {
let mut output = Vec::new();
for (direction, idx, alg, mut fut) in self.futs {
for (hash_range, mut fut) in self.futures {
let hash = fut
.try_recv()
.map_err(|_| HashCommitError::decode())?
.ok_or_else(HashCommitError::decode)?;
output.push(PlaintextHash {
direction,
idx,
direction: hash_range.direction,
idx: hash_range.range,
hash: TypedHash {
alg,
alg: hash_range.id,
value: Hash::try_from(hash).map_err(HashCommitError::convert)?,
},
});
@@ -57,107 +194,21 @@ impl HashCommitFuture {
}
}
/// Prove plaintext hash commitments.
pub(crate) fn prove_hash(
vm: &mut dyn Vm<Binary>,
refs: &TranscriptRefs,
idxs: impl IntoIterator<Item = (Direction, RangeSet<usize>, HashAlgId)>,
) -> Result<(HashCommitFuture, Vec<PlaintextHashSecret>), HashCommitError> {
let mut futs = Vec::new();
let mut secrets = Vec::new();
for (direction, idx, alg, hash_ref, blinder_ref) in
hash_commit_inner(vm, Role::Prover, refs, idxs)?
{
let blinder: Blinder = rand::random();
#[derive(Debug, Clone)]
struct HashRange {
direction: Direction,
range: RangeSet<usize>,
id: HashAlgId,
}
vm.assign(blinder_ref, blinder.as_bytes().to_vec())?;
vm.commit(blinder_ref)?;
let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;
futs.push((direction, idx.clone(), alg, hash_fut));
secrets.push(PlaintextHashSecret {
impl HashRange {
fn new(direction: Direction, range: RangeSet<usize>, id: HashAlgId) -> Self {
Self {
direction,
idx,
blinder,
alg,
});
}
Ok((HashCommitFuture { futs }, secrets))
}
/// Verify plaintext hash commitments.
pub(crate) fn verify_hash(
vm: &mut dyn Vm<Binary>,
refs: &TranscriptRefs,
idxs: impl IntoIterator<Item = (Direction, RangeSet<usize>, HashAlgId)>,
) -> Result<HashCommitFuture, HashCommitError> {
let mut futs = Vec::new();
for (direction, idx, alg, hash_ref, blinder_ref) in
hash_commit_inner(vm, Role::Verifier, refs, idxs)?
{
vm.commit(blinder_ref)?;
let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;
futs.push((direction, idx, alg, hash_fut));
}
Ok(HashCommitFuture { futs })
}
/// Commit plaintext hashes of the transcript.
#[allow(clippy::type_complexity)]
fn hash_commit_inner(
vm: &mut dyn Vm<Binary>,
role: Role,
refs: &TranscriptRefs,
idxs: impl IntoIterator<Item = (Direction, RangeSet<usize>, HashAlgId)>,
) -> Result<
Vec<(
Direction,
RangeSet<usize>,
HashAlgId,
Array<U8, 32>,
Vector<U8>,
)>,
HashCommitError,
> {
let mut output = Vec::new();
let mut hashers = HashMap::new();
for (direction, idx, alg) in idxs {
let blinder = vm.alloc_vec::<U8>(16)?;
match role {
Role::Prover => vm.mark_private(blinder)?,
Role::Verifier => vm.mark_blind(blinder)?,
range,
id,
}
let hash = match alg {
HashAlgId::SHA256 => {
let mut hasher = if let Some(hasher) = hashers.get(&alg).cloned() {
hasher
} else {
let hasher = Sha256::new_with_init(vm).map_err(HashCommitError::hasher)?;
hashers.insert(alg, hasher.clone());
hasher
};
for plaintext in refs.get(direction, &idx).expect("plaintext refs are valid") {
hasher.update(&plaintext);
}
hasher.update(&blinder);
hasher.finalize(vm).map_err(HashCommitError::hasher)?
}
alg => {
return Err(HashCommitError::unsupported_alg(alg));
}
};
output.push((direction, idx, alg, hash, blinder));
}
Ok(output)
}
/// Error type for hash commitments.
@@ -206,3 +257,148 @@ impl From<VmError> for HashCommitError {
Self(ErrorRepr::Vm(value))
}
}
#[cfg(test)]
mod test {
use crate::{
Role,
commit::{hash::PlaintextHasher, transcript::TranscriptRefs},
};
use mpz_common::context::test_st_context;
use mpz_garble_core::Delta;
use mpz_memory_core::{
MemoryExt, Vector, ViewExt,
binary::{Binary, U8},
};
use mpz_ot::ideal::rcot::{IdealRCOTReceiver, IdealRCOTSender, ideal_rcot};
use mpz_vm_core::{Execute, Vm};
use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
use rand::{Rng, SeedableRng, rngs::StdRng};
use rangeset::{RangeSet, UnionMut};
use rstest::{fixture, rstest};
use sha2::Digest;
use tlsn_core::{hash::HashAlgId, transcript::Direction};
#[rstest]
#[tokio::test]
async fn test_hasher() {
let mut sent1 = RangeSet::default();
sent1.union_mut(&(1..6));
sent1.union_mut(&(11..16));
let mut sent2 = RangeSet::default();
sent2.union_mut(&(22..25));
let mut recv = RangeSet::default();
recv.union_mut(&(20..25));
let hash_ranges = [
(Direction::Sent, sent1, HashAlgId::SHA256),
(Direction::Sent, sent2, HashAlgId::SHA256),
(Direction::Received, recv, HashAlgId::SHA256),
];
let mut refs_prover = TranscriptRefs::new(1000, 1000);
let mut refs_verifier = TranscriptRefs::new(1000, 1000);
let values = [
b"abcde".to_vec(),
b"vwxyz".to_vec(),
b"xxx".to_vec(),
b"12345".to_vec(),
];
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut values_iter = values.iter();
for (direction, idx, _) in hash_ranges.iter() {
for range in idx.iter_ranges() {
let value = values_iter.next().unwrap();
let ref_prover = assign(Role::Prover, &mut prover, value.clone());
refs_prover.add(*direction, &range, ref_prover);
let ref_verifier = assign(Role::Verifier, &mut verifier, value.clone());
refs_verifier.add(*direction, &range, ref_verifier);
}
}
let hasher_prover = PlaintextHasher::new(hash_ranges.iter());
let hasher_verifier = PlaintextHasher::new(hash_ranges.iter());
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
let (prover_hashes, prover_secrets) =
hasher_prover.prove(&mut prover, &refs_prover).unwrap();
let verifier_hashes = hasher_verifier
.verify(&mut verifier, &refs_verifier)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
let prover_hashes = prover_hashes.try_recv().unwrap();
let verifier_hashes = verifier_hashes.try_recv().unwrap();
assert_eq!(prover_hashes, verifier_hashes);
let values_per_commitment = [b"abcdevwxyz".to_vec(), b"xxx".to_vec(), b"12345".to_vec()];
for ((value, hash), secret) in values_per_commitment
.iter()
.zip(prover_hashes)
.zip(prover_secrets)
{
let blinder = secret.blinder.as_bytes();
let mut blinded_value = value.clone();
blinded_value.extend_from_slice(blinder);
let expected_hash = sha256(&blinded_value);
let hash: Vec<u8> = hash.hash.value.into();
assert_eq!(expected_hash, hash);
}
}
fn assign(role: Role, vm: &mut dyn Vm<Binary>, value: Vec<u8>) -> Vector<U8> {
let reference: Vector<U8> = vm.alloc_vec(value.len()).unwrap();
if let Role::Prover = role {
vm.mark_private(reference).unwrap();
vm.assign(reference, value).unwrap();
} else {
vm.mark_blind(reference).unwrap();
}
vm.commit(reference).unwrap();
reference
}
fn sha256(data: &[u8]) -> Vec<u8> {
let mut hasher = sha2::Sha256::default();
hasher.update(data);
hasher.finalize().as_slice().to_vec()
}
#[fixture]
fn vms() -> (Prover<IdealRCOTReceiver>, Verifier<IdealRCOTSender>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_rcot(rng.random(), delta.into_inner());
let prover = Prover::new(ProverConfig::default(), ot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta, ot_send);
(prover, verifier)
}
}

View File

@@ -1,211 +1,473 @@
use mpz_memory_core::{
MemoryExt, Vector,
binary::{Binary, U8},
};
use mpz_vm_core::{Vm, VmError};
use rangeset::{Intersection, RangeSet};
use tlsn_core::transcript::{Direction, PartialTranscript};
//! Transcript reference storage.
use std::ops::Range;
use mpz_memory_core::{FromRaw, Slice, ToRaw, Vector, binary::U8};
use rangeset::{Difference, Disjoint, RangeSet, Subset, UnionMut};
use tlsn_core::transcript::Direction;
/// References to the application plaintext in the transcript.
#[derive(Debug, Default, Clone)]
#[derive(Debug, Clone)]
pub(crate) struct TranscriptRefs {
sent: Vec<Vector<U8>>,
recv: Vec<Vector<U8>>,
sent: RefStorage,
recv: RefStorage,
}
impl TranscriptRefs {
pub(crate) fn new(sent: Vec<Vector<U8>>, recv: Vec<Vector<U8>>) -> Self {
/// Creates a new instance.
///
/// # Arguments
///
/// * `sent_max_len` - The maximum length of the sent transcript in bytes.
/// * `recv_max_len` - The maximum length of the received transcript in bytes.
pub(crate) fn new(sent_max_len: usize, recv_max_len: usize) -> Self {
let sent = RefStorage::new(sent_max_len);
let recv = RefStorage::new(recv_max_len);
Self { sent, recv }
}
/// Returns the sent plaintext references.
pub(crate) fn sent(&self) -> &[Vector<U8>] {
&self.sent
/// Adds new references to the transcript refs.
///
/// New transcript references are only added if none of them are already
/// present.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
/// * `index` - The index of the transcript references.
/// * `refs` - The new transcript refs.
pub(crate) fn add(&mut self, direction: Direction, index: &Range<usize>, refs: Vector<U8>) {
match direction {
Direction::Sent => self.sent.add(index, refs),
Direction::Received => self.recv.add(index, refs),
}
}
/// Returns the received plaintext references.
pub(crate) fn recv(&self) -> &[Vector<U8>] {
&self.recv
/// Marks references of the transcript as decoded.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
/// * `index` - The index of the transcript references.
pub(crate) fn mark_decoded(&mut self, direction: Direction, index: &RangeSet<usize>) {
match direction {
Direction::Sent => self.sent.mark_decoded(index),
Direction::Received => self.recv.mark_decoded(index),
}
}
/// Returns the transcript lengths.
pub(crate) fn len(&self) -> (usize, usize) {
let sent = self.sent.iter().map(|v| v.len()).sum();
let recv = self.recv.iter().map(|v| v.len()).sum();
(sent, recv)
/// Returns plaintext references for some index.
///
/// Queries that cannot be satisfied, or can only be partially satisfied,
/// return an empty vector.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
/// * `index` - The index of the transcript references.
pub(crate) fn get(&self, direction: Direction, index: &RangeSet<usize>) -> Vec<Vector<U8>> {
match direction {
Direction::Sent => self.sent.get(index),
Direction::Received => self.recv.get(index),
}
}
/// Returns VM references for the given direction and index, otherwise
/// `None` if the index is out of bounds.
pub(crate) fn get(
/// Computes the subset of `index` which is missing from the stored references.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
/// * `index` - The index of the transcript references.
pub(crate) fn compute_missing(
&self,
direction: Direction,
idx: &RangeSet<usize>,
) -> Option<Vec<Vector<U8>>> {
if idx.is_empty() {
return Some(Vec::new());
index: &RangeSet<usize>,
) -> RangeSet<usize> {
match direction {
Direction::Sent => self.sent.compute_missing(index),
Direction::Received => self.recv.compute_missing(index),
}
}
/// Returns the maximum length of the transcript.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
pub(crate) fn max_len(&self, direction: Direction) -> usize {
match direction {
Direction::Sent => self.sent.max_len(),
Direction::Received => self.recv.max_len(),
}
}
/// Returns the decoded ranges of the transcript.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
pub(crate) fn decoded(&self, direction: Direction) -> RangeSet<usize> {
match direction {
Direction::Sent => self.sent.decoded(),
Direction::Received => self.recv.decoded(),
}
}
/// Returns the ranges of the transcript that have been set.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
#[cfg(test)]
pub(crate) fn index(&self, direction: Direction) -> RangeSet<usize> {
match direction {
Direction::Sent => self.sent.index(),
Direction::Received => self.recv.index(),
}
}
}
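// Illustrative sketch (assumed values, not part of this diff): after committing
// bytes 0..8 and 16..24 of the sent transcript, the committed references can be
// queried back and the uncommitted gap is reported as missing:
//
//   let mut refs = TranscriptRefs::new(40, 64);
//   refs.add(Direction::Sent, &(0..8), ref_a);
//   refs.add(Direction::Sent, &(16..24), ref_b);
//
//   let committed = refs.get(Direction::Sent, &RangeSet::from(&[0..8, 16..24][..]));
//   let missing = refs.compute_missing(Direction::Sent, &RangeSet::from(&[0..24][..]));
//   // missing == {8..16}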
/// Inner storage for transcript references.
///
/// Saves transcript references by maintaining an `index` and an `offset`. The
/// offset translates from `index` to some memory location and contains
/// information about possibly non-contigious memory locations. The storage is
/// bit-addressed but the API works with ranges over bytes.
#[derive(Debug, Clone)]
struct RefStorage {
index: RangeSet<usize>,
decoded: RangeSet<usize>,
offset: Vec<isize>,
max_len: usize,
}
impl RefStorage {
fn new(max_len: usize) -> Self {
Self {
index: RangeSet::default(),
decoded: RangeSet::default(),
offset: Vec::default(),
max_len: 8 * max_len,
}
}
fn add(&mut self, index: &Range<usize>, data: Vector<U8>) {
assert!(
index.start < index.end,
"Range should be valid for adding to reference storage"
);
assert_eq!(
index.len(),
data.len(),
"Provided index and vm references should have the same length"
);
let bit_index = 8 * index.start..8 * index.end;
assert!(
bit_index.is_disjoint(&self.index),
"Parts of the provided index have already been computed"
);
assert!(
bit_index.end <= self.max_len,
"Provided index should be smaller than max_len"
);
if bit_index.end > self.offset.len() {
self.offset.resize(bit_index.end, 0);
}
let refs = match direction {
Direction::Sent => &self.sent,
Direction::Received => &self.recv,
};
let mem_address = data.to_raw().ptr().as_usize() as isize;
let offset = mem_address - bit_index.start as isize;
// Computes the transcript range for each reference.
let mut start = 0;
let mut slice_iter = refs.iter().map(move |slice| {
let out = (slice, start..start + slice.len());
start += slice.len();
out
});
self.index.union_mut(&bit_index);
self.offset[bit_index].fill(offset);
}
let mut slices = Vec::new();
let (mut slice, mut slice_range) = slice_iter.next()?;
for range in idx.iter_ranges() {
loop {
if let Some(intersection) = slice_range.intersection(&range) {
let start = intersection.start - slice_range.start;
let end = intersection.end - slice_range.start;
slices.push(slice.get(start..end).expect("range should be in bounds"));
fn mark_decoded(&mut self, index: &RangeSet<usize>) {
let bit_index = to_bit_index(index);
self.decoded.union_mut(&bit_index);
}
fn get(&self, index: &RangeSet<usize>) -> Vec<Vector<U8>> {
let bit_index = to_bit_index(index);
if bit_index.is_empty() || !bit_index.is_subset(&self.index) {
return Vec::new();
}
// Partition the rangeset into ranges mapping to possibly disjoint memory
// locations.
//
// If the offset changes while iterating over a single range, the backing
// memory is non-contiguous and we need to split that range.
let mut transcript_refs = Vec::new();
for idx in bit_index.iter_ranges() {
let mut start = idx.start;
let mut end = idx.start;
let mut offset = self.offset[start];
for k in idx {
let next_offset = self.offset[k];
if next_offset == offset {
end += 1;
continue;
}
// Proceed to next range if the current slice extends beyond. Otherwise, proceed
// to the next slice.
if range.end <= slice_range.end {
break;
} else {
(slice, slice_range) = slice_iter.next()?;
}
let len = end - start;
let ptr = (start as isize + offset) as usize;
let mem_ref = Slice::from_range_unchecked(ptr..ptr + len);
transcript_refs.push(Vector::from_raw(mem_ref));
start = k;
end = k + 1;
offset = next_offset;
}
let len = end - start;
let ptr = (start as isize + offset) as usize;
let mem_ref = Slice::from_range_unchecked(ptr..ptr + len);
transcript_refs.push(Vector::from_raw(mem_ref));
}
Some(slices)
transcript_refs
}
fn compute_missing(&self, index: &RangeSet<usize>) -> RangeSet<usize> {
let byte_index = to_byte_index(&self.index);
index.difference(&byte_index)
}
fn decoded(&self) -> RangeSet<usize> {
to_byte_index(&self.decoded)
}
fn max_len(&self) -> usize {
self.max_len / 8
}
#[cfg(test)]
fn index(&self) -> RangeSet<usize> {
to_byte_index(&self.index)
}
}
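// Worked example (sketch, assumed addresses): adding byte range 5..8 whose VM
// memory lives at bit slice 96..120 stores offset 96 - 40 = 56 for bits
// 40..64. A later `get` over {5..8} recovers ptr = 40 + 56 = 96 and returns a
// single vector backed by `Slice::from_range_unchecked(96..120)`. If the
// offset changes within a queried range, the backing memory is non-contiguous
// and the range is split into one vector per contiguous region.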
/// Decodes the transcript.
pub(crate) fn decode_transcript(
vm: &mut dyn Vm<Binary>,
sent: &RangeSet<usize>,
recv: &RangeSet<usize>,
refs: &TranscriptRefs,
) -> Result<(), VmError> {
let sent_refs = refs.get(Direction::Sent, sent).expect("index is in bounds");
let recv_refs = refs
.get(Direction::Received, recv)
.expect("index is in bounds");
fn to_bit_index(index: &RangeSet<usize>) -> RangeSet<usize> {
let mut bit_index = RangeSet::default();
for slice in sent_refs.into_iter().chain(recv_refs) {
// Drop the future, we don't need it.
drop(vm.decode(slice)?);
for r in index.iter_ranges() {
bit_index.union_mut(&(8 * r.start..8 * r.end));
}
Ok(())
bit_index
}
/// Verifies a partial transcript.
pub(crate) fn verify_transcript(
vm: &mut dyn Vm<Binary>,
transcript: &PartialTranscript,
refs: &TranscriptRefs,
) -> Result<(), InconsistentTranscript> {
let sent_refs = refs
.get(Direction::Sent, transcript.sent_authed())
.expect("index is in bounds");
let recv_refs = refs
.get(Direction::Received, transcript.received_authed())
.expect("index is in bounds");
fn to_byte_index(index: &RangeSet<usize>) -> RangeSet<usize> {
let mut byte_index = RangeSet::default();
let mut authenticated_data = Vec::new();
for data in sent_refs.into_iter().chain(recv_refs) {
let plaintext = vm
.get(data)
.expect("reference is valid")
.expect("plaintext is decoded");
authenticated_data.extend_from_slice(&plaintext);
for r in index.iter_ranges() {
let start = r.start;
let end = r.end;
assert!(
start.trailing_zeros() >= 3,
"start range should be divisible by 8"
);
assert!(
end.trailing_zeros() >= 3,
"end range should be divisible by 8"
);
let start = start >> 3;
let end = end >> 3;
byte_index.union_mut(&(start..end));
}
let mut purported_data = Vec::with_capacity(authenticated_data.len());
for range in transcript.sent_authed().iter_ranges() {
purported_data.extend_from_slice(&transcript.sent_unsafe()[range]);
}
for range in transcript.received_authed().iter_ranges() {
purported_data.extend_from_slice(&transcript.received_unsafe()[range]);
}
if purported_data != authenticated_data {
return Err(InconsistentTranscript {});
}
Ok(())
byte_index
}
/// Error for [`verify_transcript`].
#[derive(Debug, thiserror::Error)]
#[error("inconsistent transcript")]
pub(crate) struct InconsistentTranscript {}
#[cfg(test)]
mod tests {
use super::TranscriptRefs;
use mpz_memory_core::{FromRaw, Slice, Vector, binary::U8};
use rangeset::RangeSet;
use crate::commit::transcript::RefStorage;
use mpz_memory_core::{FromRaw, Slice, ToRaw, Vector, binary::U8};
use rangeset::{RangeSet, UnionMut};
use rstest::{fixture, rstest};
use std::ops::Range;
use tlsn_core::transcript::Direction;
// TRANSCRIPT_REFS:
//
// 48..96 -> 6 slots
// 112..176 -> 8 slots
// 240..288 -> 6 slots
// 352..392 -> 5 slots
// 440..480 -> 5 slots
const TRANSCRIPT_REFS: &[Range<usize>] = &[48..96, 112..176, 240..288, 352..392, 440..480];
#[rstest]
fn test_storage_add(
max_len: usize,
ranges: [Range<usize>; 6],
offsets: [isize; 6],
storage: RefStorage,
) {
let bit_ranges: Vec<Range<usize>> = ranges.iter().map(|r| 8 * r.start..8 * r.end).collect();
let bit_offsets: Vec<isize> = offsets.iter().map(|o| 8 * o).collect();
const IDXS: &[Range<usize>] = &[0..4, 5..10, 14..16, 16..28];
let mut expected_index: RangeSet<usize> = RangeSet::default();
// 1. Take slots 0..4, 4 slots -> 48..80 (4)
// 2. Take slots 5..10, 5 slots -> 88..96 (1) + 112..144 (4)
// 3. Take slots 14..16, 2 slots -> 240..256 (2)
// 4. Take slots 16..28, 12 slots -> 256..288 (4) + 352..392 (5) + 440..464 (3)
//
// 5. Merge slots 240..256 and 256..288 => 240..288 and get EXPECTED_REFS
const EXPECTED_REFS: &[Range<usize>] =
&[48..80, 88..96, 112..144, 240..288, 352..392, 440..464];
expected_index.union_mut(&bit_ranges[0]);
expected_index.union_mut(&bit_ranges[1]);
#[test]
fn test_transcript_refs_get() {
let transcript_refs: Vec<Vector<U8>> = TRANSCRIPT_REFS
.iter()
.cloned()
.map(|range| Vector::from_raw(Slice::from_range_unchecked(range)))
.collect();
expected_index.union_mut(&bit_ranges[2]);
expected_index.union_mut(&bit_ranges[3]);
let transcript_refs = TranscriptRefs {
sent: transcript_refs.clone(),
recv: transcript_refs,
};
expected_index.union_mut(&bit_ranges[4]);
expected_index.union_mut(&bit_ranges[5]);
assert_eq!(storage.index, expected_index);
let vm_refs = transcript_refs
.get(Direction::Sent, &RangeSet::from(IDXS))
.unwrap();
let end = expected_index.end().unwrap();
let mut expected_offset = vec![0_isize; end];
let expected_refs: Vec<Vector<U8>> = EXPECTED_REFS
.iter()
.cloned()
.map(|range| Vector::from_raw(Slice::from_range_unchecked(range)))
.collect();
expected_offset[bit_ranges[0].clone()].fill(bit_offsets[0]);
expected_offset[bit_ranges[1].clone()].fill(bit_offsets[1]);
assert_eq!(
vm_refs.len(),
expected_refs.len(),
"Length of actual and expected refs are not equal"
);
expected_offset[bit_ranges[2].clone()].fill(bit_offsets[2]);
expected_offset[bit_ranges[3].clone()].fill(bit_offsets[3]);
for (&expected, actual) in expected_refs.iter().zip(vm_refs) {
assert_eq!(expected, actual);
expected_offset[bit_ranges[4].clone()].fill(bit_offsets[4]);
expected_offset[bit_ranges[5].clone()].fill(bit_offsets[5]);
assert_eq!(storage.offset, expected_offset);
assert_eq!(storage.decoded, RangeSet::default());
assert_eq!(storage.max_len, 8 * max_len);
}
#[rstest]
fn test_storage_get(ranges: [Range<usize>; 6], offsets: [isize; 6], storage: RefStorage) {
let mut index = RangeSet::default();
ranges.iter().for_each(|r| index.union_mut(r));
let data = storage.get(&index);
let mut data_recovered = Vec::new();
for (r, o) in ranges.iter().zip(offsets) {
data_recovered.push(vec(r.start as isize + o..r.end as isize + o));
}
// Merge possibly adjacent vectors.
//
// Two vectors are adjacent if
//
// - the vectors are adjacent in memory, and
// - the transcript ranges of those vectors are adjacent, too.
let mut range_iter = ranges.iter();
let mut vec_iter = data_recovered.iter();
let mut data_expected = Vec::new();
let mut current_vec = vec_iter.next().unwrap().to_raw().to_range();
let mut current_range = range_iter.next().unwrap();
for (r, v) in range_iter.zip(vec_iter) {
let v_range = v.to_raw().to_range();
let start = v_range.start;
let end = v_range.end;
if current_vec.end == start && current_range.end == r.start {
current_vec.end = end;
} else {
let v = Vector::<U8>::from_raw(Slice::from_range_unchecked(current_vec));
data_expected.push(v);
current_vec = start..end;
current_range = r;
}
}
let v = Vector::<U8>::from_raw(Slice::from_range_unchecked(current_vec));
data_expected.push(v);
assert_eq!(data, data_expected);
}
#[rstest]
fn test_storage_compute_missing(storage: RefStorage) {
let mut range = RangeSet::default();
range.union_mut(&(6..12));
range.union_mut(&(18..21));
range.union_mut(&(22..25));
range.union_mut(&(50..60));
let missing = storage.compute_missing(&range);
let mut missing_expected = RangeSet::default();
missing_expected.union_mut(&(8..12));
missing_expected.union_mut(&(20..21));
missing_expected.union_mut(&(50..60));
assert_eq!(missing, missing_expected);
}
#[rstest]
fn test_mark_decoded(mut storage: RefStorage) {
let mut range = RangeSet::default();
range.union_mut(&(14..17));
range.union_mut(&(30..37));
storage.mark_decoded(&range);
let decoded = storage.decoded();
assert_eq!(range, decoded);
}
#[fixture]
fn max_len() -> usize {
1000
}
#[fixture]
fn ranges() -> [Range<usize>; 6] {
let r1 = 0..5;
let r2 = 5..8;
let r3 = 12..20;
let r4 = 22..26;
let r5 = 30..35;
let r6 = 35..38;
[r1, r2, r3, r4, r5, r6]
}
#[fixture]
fn offsets() -> [isize; 6] {
[7, 9, 20, 18, 30, 30]
}
// expected memory ranges: 8 * ranges + 8 * offsets
// 1. 56..96 do not merge with next one, because not adjacent in memory
// 2. 112..136
// 3. 256..320 do not merge with next one: adjacent in memory, but the ranges
//    themselves are not
// 4. 320..352
// 5. 480..520 merge with next one
// 6. 520..544
//
//
// 1. 56..96, length: 5
// 2. 112..136, length: 3
// 3. 256..320, length: 8
// 4. 320..352, length: 4
// 5. 480..544, length: 8
#[fixture]
fn storage(max_len: usize, ranges: [Range<usize>; 6], offsets: [isize; 6]) -> RefStorage {
let [r1, r2, r3, r4, r5, r6] = ranges;
let [o1, o2, o3, o4, o5, o6] = offsets;
let mut storage = RefStorage::new(max_len);
storage.add(&r1, vec(r1.start as isize + o1..r1.end as isize + o1));
storage.add(&r2, vec(r2.start as isize + o2..r2.end as isize + o2));
storage.add(&r3, vec(r3.start as isize + o3..r3.end as isize + o3));
storage.add(&r4, vec(r4.start as isize + o4..r4.end as isize + o4));
storage.add(&r5, vec(r5.start as isize + o5..r5.end as isize + o5));
storage.add(&r6, vec(r6.start as isize + o6..r6.end as isize + o6));
storage
}
fn vec(range: Range<isize>) -> Vector<U8> {
let range = 8 * range.start as usize..8 * range.end as usize;
Vector::from_raw(Slice::from_range_unchecked(range))
}
}

View File

@@ -1,249 +0,0 @@
//! Encoding commitment protocol.
use std::ops::Range;
use mpz_common::Context;
use mpz_memory_core::{
Vector,
binary::U8,
correlated::{Delta, Key, Mac},
};
use rand::Rng;
use rangeset::RangeSet;
use serde::{Deserialize, Serialize};
use serio::{SinkExt, stream::IoStreamExt};
use tlsn_core::{
hash::HashAlgorithm,
transcript::{
Direction,
encoding::{
Encoder, EncoderSecret, EncodingCommitment, EncodingProvider, EncodingProviderError,
EncodingTree, EncodingTreeError, new_encoder,
},
},
};
use crate::commit::transcript::TranscriptRefs;
/// Bytes of encoding, per byte.
const ENCODING_SIZE: usize = 128;
#[derive(Debug, Serialize, Deserialize)]
struct Encodings {
sent: Vec<u8>,
recv: Vec<u8>,
}
/// Transfers the encodings using the provided seed and keys.
///
/// The keys must be consistent with the global delta used in the encodings.
pub(crate) async fn transfer<'a>(
ctx: &mut Context,
refs: &TranscriptRefs,
delta: &Delta,
f: impl Fn(Vector<U8>) -> &'a [Key],
) -> Result<EncodingCommitment, EncodingError> {
let secret = EncoderSecret::new(rand::rng().random(), delta.as_block().to_bytes());
let encoder = new_encoder(&secret);
let sent_keys: Vec<u8> = refs
.sent()
.iter()
.copied()
.flat_map(&f)
.flat_map(|key| key.as_block().as_bytes())
.copied()
.collect();
let recv_keys: Vec<u8> = refs
.recv()
.iter()
.copied()
.flat_map(&f)
.flat_map(|key| key.as_block().as_bytes())
.copied()
.collect();
assert_eq!(sent_keys.len() % ENCODING_SIZE, 0);
assert_eq!(recv_keys.len() % ENCODING_SIZE, 0);
let mut sent_encoding = Vec::with_capacity(sent_keys.len());
let mut recv_encoding = Vec::with_capacity(recv_keys.len());
encoder.encode_range(
Direction::Sent,
0..sent_keys.len() / ENCODING_SIZE,
&mut sent_encoding,
);
encoder.encode_range(
Direction::Received,
0..recv_keys.len() / ENCODING_SIZE,
&mut recv_encoding,
);
sent_encoding
.iter_mut()
.zip(sent_keys)
.for_each(|(enc, key)| *enc ^= key);
recv_encoding
.iter_mut()
.zip(recv_keys)
.for_each(|(enc, key)| *enc ^= key);
// Set frame limit and add some extra bytes of cushion.
let (sent, recv) = refs.len();
let frame_limit = ENCODING_SIZE * (sent + recv) + ctx.io().limit();
ctx.io_mut()
.with_limit(frame_limit)
.send(Encodings {
sent: sent_encoding,
recv: recv_encoding,
})
.await?;
let root = ctx.io_mut().expect_next().await?;
ctx.io_mut().send(secret.clone()).await?;
Ok(EncodingCommitment {
root,
secret: secret.clone(),
})
}
/// Receives the encodings using the provided MACs.
///
/// The MACs must be consistent with the global delta used in the encodings.
pub(crate) async fn receive<'a>(
ctx: &mut Context,
hasher: &(dyn HashAlgorithm + Send + Sync),
refs: &TranscriptRefs,
f: impl Fn(Vector<U8>) -> &'a [Mac],
idxs: impl IntoIterator<Item = &(Direction, RangeSet<usize>)>,
) -> Result<(EncodingCommitment, EncodingTree), EncodingError> {
// Set frame limit and add some extra bytes of cushion.
let (sent, recv) = refs.len();
let frame_limit = ENCODING_SIZE * (sent + recv) + ctx.io().limit();
let Encodings { mut sent, mut recv } =
ctx.io_mut().with_limit(frame_limit).expect_next().await?;
let sent_macs: Vec<u8> = refs
.sent()
.iter()
.copied()
.flat_map(&f)
.flat_map(|mac| mac.as_bytes())
.copied()
.collect();
let recv_macs: Vec<u8> = refs
.recv()
.iter()
.copied()
.flat_map(&f)
.flat_map(|mac| mac.as_bytes())
.copied()
.collect();
assert_eq!(sent_macs.len() % ENCODING_SIZE, 0);
assert_eq!(recv_macs.len() % ENCODING_SIZE, 0);
if sent.len() != sent_macs.len() {
return Err(ErrorRepr::IncorrectMacCount {
direction: Direction::Sent,
expected: sent_macs.len(),
got: sent.len(),
}
.into());
}
if recv.len() != recv_macs.len() {
return Err(ErrorRepr::IncorrectMacCount {
direction: Direction::Received,
expected: recv_macs.len(),
got: recv.len(),
}
.into());
}
sent.iter_mut()
.zip(sent_macs)
.for_each(|(enc, mac)| *enc ^= mac);
recv.iter_mut()
.zip(recv_macs)
.for_each(|(enc, mac)| *enc ^= mac);
let provider = Provider { sent, recv };
let tree = EncodingTree::new(hasher, idxs, &provider)?;
let root = tree.root();
ctx.io_mut().send(root.clone()).await?;
let secret = ctx.io_mut().expect_next().await?;
let commitment = EncodingCommitment { root, secret };
Ok((commitment, tree))
}
#[derive(Debug)]
struct Provider {
sent: Vec<u8>,
recv: Vec<u8>,
}
impl EncodingProvider for Provider {
fn provide_encoding(
&self,
direction: Direction,
range: Range<usize>,
dest: &mut Vec<u8>,
) -> Result<(), EncodingProviderError> {
let encodings = match direction {
Direction::Sent => &self.sent,
Direction::Received => &self.recv,
};
let start = range.start * ENCODING_SIZE;
let end = range.end * ENCODING_SIZE;
if end > encodings.len() {
return Err(EncodingProviderError);
}
dest.extend_from_slice(&encodings[start..end]);
Ok(())
}
}
/// Encoding protocol error.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct EncodingError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
#[error("encoding protocol error: {0}")]
enum ErrorRepr {
#[error("I/O error: {0}")]
Io(std::io::Error),
#[error("incorrect MAC count for {direction}: expected {expected}, got {got}")]
IncorrectMacCount {
direction: Direction,
expected: usize,
got: usize,
},
#[error("encoding tree error: {0}")]
EncodingTree(EncodingTreeError),
}
impl From<std::io::Error> for EncodingError {
fn from(value: std::io::Error) -> Self {
Self(ErrorRepr::Io(value))
}
}
impl From<EncodingTreeError> for EncodingError {
fn from(value: EncodingTreeError) -> Self {
Self(ErrorRepr::EncodingTree(value))
}
}

View File

@@ -7,7 +7,6 @@
pub(crate) mod commit;
pub mod config;
pub(crate) mod context;
pub(crate) mod encoding;
pub(crate) mod ghash;
pub(crate) mod msg;
pub(crate) mod mux;

View File

@@ -8,49 +8,41 @@ pub mod state;
pub use config::{ProverConfig, ProverConfigBuilder, TlsConfig, TlsConfigBuilder};
pub use error::ProverError;
pub use future::ProverFuture;
use rustls_pki_types::CertificateDer;
pub use tlsn_core::{ProveConfig, ProveConfigBuilder, ProveConfigBuilderError, ProverOutput};
use std::sync::Arc;
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
use mpz_common::Context;
use mpz_core::Block;
use mpz_garble_core::Delta;
use mpz_vm_core::prelude::*;
use mpz_zk::ProverConfig as ZkProverConfig;
use rand::Rng;
use rustls_pki_types::CertificateDer;
use serio::SinkExt;
use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client};
use tlsn_core::{
ProveRequest,
connection::{HandshakeData, ServerName},
transcript::{TlsTranscript, Transcript},
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Instrument, Span, debug, info, info_span, instrument};
use webpki::anchor_from_trusted_cert;
use crate::{
Role,
commit::{
commit_records,
hash::prove_hash,
transcript::{TranscriptRefs, decode_transcript},
},
commit::{ProvingState, TranscriptRefs},
context::build_mt_context,
encoding,
mux::attach_mux,
tag::verify_tags,
zk_aes_ctr::ZkAesCtr,
};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
use rand::Rng;
use serio::SinkExt;
use std::sync::Arc;
use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client};
use tls_core::msgs::enums::ContentType;
use tlsn_core::{
ProvePayload,
connection::{HandshakeData, ServerName},
hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256},
transcript::{TlsTranscript, Transcript, TranscriptCommitment, TranscriptSecret},
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Instrument, Span, debug, info, info_span, instrument};
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>,
mpz_core::Block,
@@ -173,8 +165,8 @@ impl Prover<state::Setup> {
mux_ctrl,
mut mux_fut,
mpc_tls,
mut zk_aes_ctr_sent,
mut zk_aes_ctr_recv,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
vm,
..
@@ -281,28 +273,6 @@ impl Prover<state::Setup> {
)
.map_err(ProverError::zk)?;
// Prove received plaintext. Prover drops the proof output, as
// they trust themselves.
let (sent_refs, _) = commit_records(
&mut vm,
&mut zk_aes_ctr_sent,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(ProverError::zk)?;
let (recv_refs, _) = commit_records(
&mut vm,
&mut zk_aes_ctr_recv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(ProverError::zk)?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
.await?;
@@ -310,7 +280,9 @@ impl Prover<state::Setup> {
let transcript = tls_transcript
.to_transcript()
.expect("transcript is complete");
let transcript_refs = TranscriptRefs::new(sent_refs, recv_refs);
let (sent_len, recv_len) = transcript.len();
let transcript_refs = TranscriptRefs::new(sent_len, recv_len);
Ok(Prover {
config: self.config,
@@ -323,6 +295,10 @@ impl Prover<state::Setup> {
tls_transcript,
transcript,
transcript_refs,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
encodings_transferred: false,
},
})
}
@@ -356,7 +332,7 @@ impl Prover<state::Committed> {
///
/// * `config` - The disclosure configuration.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn prove(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> {
pub async fn prove(&mut self, config: ProveConfig) -> Result<ProverOutput, ProverError> {
let state::Committed {
mux_fut,
ctx,
@@ -364,114 +340,48 @@ impl Prover<state::Committed> {
tls_transcript,
transcript,
transcript_refs,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
encodings_transferred,
..
} = &mut self.state;
let mut output = ProverOutput {
transcript_commitments: Vec::new(),
transcript_secrets: Vec::new(),
};
// Create and send prove payload.
let server_name = self.config.server_name();
let handshake = config
.server_identity()
.then(|| (server_name.clone(), HandshakeData::new(tls_transcript)));
let partial_transcript = if let Some((sent, recv)) = config.reveal() {
decode_transcript(vm, sent, recv, transcript_refs).map_err(ProverError::zk)?;
Some(transcript.to_partial(sent.clone(), recv.clone()))
let partial = if let Some((reveal_sent, reveal_recv)) = config.reveal() {
Some(transcript.to_partial(reveal_sent.clone(), reveal_recv.clone()))
} else {
None
};
let payload = ProvePayload {
handshake: config.server_identity().then(|| {
(
self.config.server_name().clone(),
HandshakeData {
certs: tls_transcript
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: tls_transcript
.server_signature()
.expect("server signature is present")
.clone(),
binding: tls_transcript.certificate_binding().clone(),
},
)
}),
transcript: partial_transcript,
transcript_commit: config.transcript_commit().map(|config| config.to_request()),
};
let payload = ProveRequest::new(&config, partial, handshake);
// Send payload.
mux_fut
.poll_with(ctx.io_mut().send(payload).map_err(ProverError::from))
.await?;
let mut hash_commitments = None;
if let Some(commit_config) = config.transcript_commit() {
if commit_config.has_encoding() {
let hasher: &(dyn HashAlgorithm + Send + Sync) =
match *commit_config.encoding_hash_alg() {
HashAlgId::SHA256 => &Sha256::default(),
HashAlgId::KECCAK256 => &Keccak256::default(),
HashAlgId::BLAKE3 => &Blake3::default(),
alg => {
return Err(ProverError::config(format!(
"unsupported hash algorithm for encoding commitment: {alg}"
)));
}
};
let proving_state = ProvingState::for_prover(
config,
tls_transcript,
transcript,
transcript_refs,
*encodings_transferred,
);
let (commitment, tree) = mux_fut
.poll_with(
encoding::receive(
ctx,
hasher,
transcript_refs,
|plaintext| vm.get_macs(plaintext).expect("reference is valid"),
commit_config.iter_encoding(),
)
.map_err(ProverError::commit),
)
.await?;
output
.transcript_commitments
.push(TranscriptCommitment::Encoding(commitment));
output
.transcript_secrets
.push(TranscriptSecret::Encoding(tree));
}
if commit_config.has_hash() {
hash_commitments = Some(
prove_hash(
vm,
transcript_refs,
commit_config
.iter_hash()
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg)),
)
.map_err(ProverError::commit)?,
);
}
}
mux_fut
.poll_with(vm.execute_all(ctx).map_err(ProverError::zk))
let (output, encodings_executed) = mux_fut
.poll_with(
proving_state
.prove(vm, ctx, zk_aes_ctr_sent, zk_aes_ctr_recv, keys.clone())
.map_err(ProverError::from),
)
.await?;
if let Some((hash_fut, hash_secrets)) = hash_commitments {
let hash_commitments = hash_fut.try_recv().map_err(ProverError::commit)?;
for (commitment, secret) in hash_commitments.into_iter().zip(hash_secrets) {
output
.transcript_commitments
.push(TranscriptCommitment::Hash(commitment));
output
.transcript_secrets
.push(TranscriptSecret::Hash(secret));
}
}
*encodings_transferred = encodings_executed;
Ok(output)
}

View File

@@ -1,8 +1,6 @@
use std::{error::Error, fmt};
use crate::{commit::CommitError, zk_aes_ctr::ZkAesCtrError};
use mpc_tls::MpcTlsError;
use crate::{encoding::EncodingError, zk_aes_ctr::ZkAesCtrError};
use std::{error::Error, fmt};
/// Error for [`Prover`](crate::Prover).
#[derive(Debug, thiserror::Error)]
@@ -42,13 +40,6 @@ impl ProverError {
{
Self::new(ErrorKind::Zk, source)
}
pub(crate) fn commit<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Commit, source)
}
}
#[derive(Debug)]
@@ -116,8 +107,8 @@ impl From<ZkAesCtrError> for ProverError {
}
}
impl From<EncodingError> for ProverError {
fn from(e: EncodingError) -> Self {
impl From<CommitError> for ProverError {
fn from(e: CommitError) -> Self {
Self::new(ErrorKind::Commit, e)
}
}

View File

@@ -9,7 +9,7 @@ use tlsn_deap::Deap;
use tokio::sync::Mutex;
use crate::{
commit::transcript::TranscriptRefs,
commit::TranscriptRefs,
mux::{MuxControl, MuxFuture},
prover::{Mpc, Zk},
zk_aes_ctr::ZkAesCtr,
@@ -42,6 +42,10 @@ pub struct Committed {
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript: Transcript,
pub(crate) transcript_refs: TranscriptRefs,
pub(crate) zk_aes_ctr_sent: ZkAesCtr,
pub(crate) zk_aes_ctr_recv: ZkAesCtr,
pub(crate) keys: SessionKeys,
pub(crate) encodings_transferred: bool,
}
opaque_debug::implement!(Committed);

View File

@@ -1,11 +1,9 @@
//! Verifier.
pub(crate) mod config;
mod config;
mod error;
pub mod state;
use std::sync::Arc;
pub use config::{VerifierConfig, VerifierConfigBuilder, VerifierConfigBuilderError};
pub use error::VerifierError;
pub use tlsn_core::{
@@ -13,20 +11,8 @@ pub use tlsn_core::{
webpki::ServerCertVerifier,
};
use crate::{
Role,
commit::{
commit_records,
hash::verify_hash,
transcript::{TranscriptRefs, decode_transcript, verify_transcript},
},
config::ProtocolConfig,
context::build_mt_context,
encoding,
mux::attach_mux,
tag::verify_tags,
zk_aes_ctr::ZkAesCtr,
};
use std::sync::Arc;
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
@@ -35,17 +21,25 @@ use mpz_garble_core::Delta;
use mpz_vm_core::prelude::*;
use mpz_zk::VerifierConfig as ZkVerifierConfig;
use serio::stream::IoStreamExt;
use tls_core::msgs::enums::ContentType;
use tlsn_core::{
ProvePayload,
ProveRequest,
connection::{ConnectionInfo, ServerName},
transcript::{TlsTranscript, TranscriptCommitment},
transcript::{ContentType, TlsTranscript},
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Span, debug, info, info_span, instrument};
use crate::{
Role,
commit::{ProvingState, TranscriptRefs},
config::ProtocolConfig,
context::build_mt_context,
mux::attach_mux,
tag::verify_tags,
zk_aes_ctr::ZkAesCtr,
};
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
mpz_ot::ferret::Sender<mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>>,
mpz_core::Block,
@@ -188,8 +182,8 @@ impl Verifier<state::Setup> {
mut mux_fut,
delta,
mpc_tls,
mut zk_aes_ctr_sent,
mut zk_aes_ctr_recv,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
vm,
keys,
} = self.state;
@@ -230,27 +224,6 @@ impl Verifier<state::Setup> {
)
.map_err(VerifierError::zk)?;
// Prepare for the prover to prove received plaintext.
let (sent_refs, sent_proof) = commit_records(
&mut vm,
&mut zk_aes_ctr_sent,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(VerifierError::zk)?;
let (recv_refs, recv_proof) = commit_records(
&mut vm,
&mut zk_aes_ctr_recv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(VerifierError::zk)?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(VerifierError::zk))
.await?;
@@ -260,11 +233,30 @@ impl Verifier<state::Setup> {
// authenticated from the verifier's perspective.
tag_proof.verify().map_err(VerifierError::zk)?;
// Verify the plaintext proofs.
sent_proof.verify().map_err(VerifierError::zk)?;
recv_proof.verify().map_err(VerifierError::zk)?;
let sent_len = tls_transcript
.sent()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
let recv_len = tls_transcript
.recv()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
let transcript_refs = TranscriptRefs::new(sent_refs, recv_refs);
let transcript_refs = TranscriptRefs::new(sent_len, recv_len);
Ok(Verifier {
config: self.config,
@@ -277,6 +269,11 @@ impl Verifier<state::Setup> {
vm,
tls_transcript,
transcript_refs,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
verified_server_name: None,
encodings_transferred: false,
},
})
}
@@ -305,126 +302,42 @@ impl Verifier<state::Committed> {
vm,
tls_transcript,
transcript_refs,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
verified_server_name,
encodings_transferred,
..
} = &mut self.state;
let ProvePayload {
handshake,
transcript,
transcript_commit,
} = mux_fut
let payload: ProveRequest = mux_fut
.poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from))
.await?;
let verifier = if let Some(root_store) = self.config.root_store() {
ServerCertVerifier::new(root_store).map_err(VerifierError::config)?
} else {
ServerCertVerifier::mozilla()
};
let proving_state = ProvingState::for_verifier(
payload,
tls_transcript,
transcript_refs,
verified_server_name.clone(),
*encodings_transferred,
);
let server_name = if let Some((name, cert_data)) = handshake {
cert_data
.verify(
&verifier,
tls_transcript.time(),
tls_transcript.server_ephemeral_key(),
&name,
)
.map_err(VerifierError::verify)?;
Some(name)
} else {
None
};
if let Some(partial_transcript) = &transcript {
let sent_len = tls_transcript
.sent()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
let recv_len = tls_transcript
.recv()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
// Check ranges.
if partial_transcript.len_sent() != sent_len
|| partial_transcript.len_received() != recv_len
{
return Err(VerifierError::verify(
"prover sent transcript with incorrect length",
));
}
decode_transcript(
let (output, encodings_executed) = mux_fut
.poll_with(proving_state.verify(
vm,
partial_transcript.sent_authed(),
partial_transcript.received_authed(),
transcript_refs,
)
.map_err(VerifierError::zk)?;
}
let mut transcript_commitments = Vec::new();
let mut hash_commitments = None;
if let Some(commit_config) = transcript_commit {
if commit_config.encoding() {
let commitment = mux_fut
.poll_with(encoding::transfer(
ctx,
transcript_refs,
delta,
|plaintext| vm.get_keys(plaintext).expect("reference is valid"),
))
.await?;
transcript_commitments.push(TranscriptCommitment::Encoding(commitment));
}
if commit_config.has_hash() {
hash_commitments = Some(
verify_hash(vm, transcript_refs, commit_config.iter_hash().cloned())
.map_err(VerifierError::verify)?,
);
}
}
mux_fut
.poll_with(vm.execute_all(ctx).map_err(VerifierError::zk))
ctx,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys.clone(),
*delta,
self.config.root_store(),
))
.await?;
// Verify revealed data.
if let Some(partial_transcript) = &transcript {
verify_transcript(vm, partial_transcript, transcript_refs)
.map_err(VerifierError::verify)?;
}
*verified_server_name = output.server_name.clone();
*encodings_transferred = encodings_executed;
if let Some(hash_commitments) = hash_commitments {
for commitment in hash_commitments.try_recv().map_err(VerifierError::verify)? {
transcript_commitments.push(TranscriptCommitment::Hash(commitment));
}
}
Ok(VerifierOutput {
server_name,
transcript,
transcript_commitments,
})
Ok(output)
}
/// Closes the connection with the prover.

View File

@@ -1,4 +1,4 @@
use crate::{encoding::EncodingError, zk_aes_ctr::ZkAesCtrError};
use crate::{commit::CommitError, zk_aes_ctr::ZkAesCtrError};
use mpc_tls::MpcTlsError;
use std::{error::Error, fmt};
@@ -20,13 +20,6 @@ impl VerifierError {
}
}
pub(crate) fn config<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Config, source)
}
pub(crate) fn mpc<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
@@ -40,13 +33,6 @@ impl VerifierError {
{
Self::new(ErrorKind::Zk, source)
}
pub(crate) fn verify<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Verify, source)
}
}
#[derive(Debug)]
@@ -56,7 +42,6 @@ enum ErrorKind {
Mpc,
Zk,
Commit,
Verify,
}
impl fmt::Display for VerifierError {
@@ -69,7 +54,6 @@ impl fmt::Display for VerifierError {
ErrorKind::Mpc => f.write_str("mpc error")?,
ErrorKind::Zk => f.write_str("zk error")?,
ErrorKind::Commit => f.write_str("commit error")?,
ErrorKind::Verify => f.write_str("verification error")?,
}
if let Some(source) = &self.source {
@@ -116,8 +100,8 @@ impl From<ZkAesCtrError> for VerifierError {
}
}
impl From<EncodingError> for VerifierError {
fn from(e: EncodingError) -> Self {
impl From<CommitError> for VerifierError {
fn from(e: CommitError) -> Self {
Self::new(ErrorKind::Commit, e)
}
}

View File

@@ -3,14 +3,14 @@
use std::sync::Arc;
use crate::{
commit::transcript::TranscriptRefs,
commit::TranscriptRefs,
mux::{MuxControl, MuxFuture},
zk_aes_ctr::ZkAesCtr,
};
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
use mpz_memory_core::correlated::Delta;
use tlsn_core::transcript::TlsTranscript;
use tlsn_core::{connection::ServerName, transcript::TlsTranscript};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
@@ -45,6 +45,11 @@ pub struct Committed {
pub(crate) vm: Zk,
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript_refs: TranscriptRefs,
pub(crate) zk_aes_ctr_sent: ZkAesCtr,
pub(crate) zk_aes_ctr_recv: ZkAesCtr,
pub(crate) keys: SessionKeys,
pub(crate) verified_server_name: Option<ServerName>,
pub(crate) encodings_transferred: bool,
}
opaque_debug::implement!(Committed);

View File

@@ -37,8 +37,8 @@ impl ZkAesCtr {
}
/// Returns the role.
pub(crate) fn role(&self) -> &Role {
&self.role
pub(crate) fn role(&self) -> Role {
self.role
}
/// Allocates `len` bytes for encryption.

View File

@@ -103,7 +103,7 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(verifier_soc
let config = builder.build().unwrap();
prover.prove(&config).await.unwrap();
prover.prove(config).await.unwrap();
prover.close().await.unwrap();
}

View File

@@ -126,7 +126,7 @@ impl JsProver {
let config = builder.build()?;
prover.prove(&config).await?;
prover.prove(config).await?;
prover.close().await?;
info!("Finalized");