mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-01-11 23:58:03 -05:00
Compare commits
1 Commits
refactor/i
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1801c30599 |
3
.github/workflows/ci.yml
vendored
3
.github/workflows/ci.yml
vendored
@@ -21,8 +21,7 @@ env:
|
||||
# - https://github.com/privacy-ethereum/mpz/issues/178
|
||||
# 32 seems to be big enough for the foreseeable future
|
||||
RAYON_NUM_THREADS: 32
|
||||
RUST_VERSION: 1.92.0
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
RUST_VERSION: 1.90.0
|
||||
|
||||
jobs:
|
||||
clippy:
|
||||
|
||||
2
.github/workflows/releng.yml
vendored
2
.github/workflows/releng.yml
vendored
@@ -6,7 +6,7 @@ on:
|
||||
tag:
|
||||
description: 'Tag to publish to NPM'
|
||||
required: true
|
||||
default: 'v0.1.0-alpha.14-pre'
|
||||
default: 'v0.1.0-alpha.13'
|
||||
|
||||
jobs:
|
||||
release:
|
||||
|
||||
1
.github/workflows/updatemain.yml
vendored
1
.github/workflows/updatemain.yml
vendored
@@ -1,7 +1,6 @@
|
||||
name: Fast-forward main branch to published release tag
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
|
||||
1143
Cargo.lock
generated
1143
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
38
Cargo.toml
38
Cargo.toml
@@ -13,6 +13,7 @@ members = [
|
||||
"crates/server-fixture/server",
|
||||
"crates/tls/backend",
|
||||
"crates/tls/client",
|
||||
"crates/tls/client-async",
|
||||
"crates/tls/core",
|
||||
"crates/mpc-tls",
|
||||
"crates/tls/server-fixture",
|
||||
@@ -56,6 +57,7 @@ tlsn-server-fixture = { path = "crates/server-fixture/server" }
|
||||
tlsn-server-fixture-certs = { path = "crates/server-fixture/certs" }
|
||||
tlsn-tls-backend = { path = "crates/tls/backend" }
|
||||
tlsn-tls-client = { path = "crates/tls/client" }
|
||||
tlsn-tls-client-async = { path = "crates/tls/client-async" }
|
||||
tlsn-tls-core = { path = "crates/tls/core" }
|
||||
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
|
||||
tlsn-harness-core = { path = "crates/harness/core" }
|
||||
@@ -64,28 +66,26 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
|
||||
tlsn-wasm = { path = "crates/wasm" }
|
||||
tlsn = { path = "crates/tlsn" }
|
||||
|
||||
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-circuits-data = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
|
||||
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
|
||||
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
|
||||
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
|
||||
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
|
||||
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
|
||||
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
|
||||
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
|
||||
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
|
||||
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
|
||||
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
|
||||
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
|
||||
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
|
||||
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
|
||||
|
||||
futures-plex = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "c210f2f" }
|
||||
rangeset = { version = "0.4" }
|
||||
rangeset = { version = "0.2" }
|
||||
serio = { version = "0.2" }
|
||||
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
|
||||
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
|
||||
uid-mux = { version = "0.2" }
|
||||
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
|
||||
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
|
||||
|
||||
aead = { version = "0.4" }
|
||||
aes = { version = "0.8" }
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tlsn-attestation"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
version = "0.1.0-alpha.13"
|
||||
edition = "2024"
|
||||
|
||||
[features]
|
||||
@@ -9,7 +9,7 @@ fixtures = ["tlsn-core/fixtures", "dep:tlsn-data-fixtures"]
|
||||
|
||||
[dependencies]
|
||||
tlsn-tls-core = { workspace = true }
|
||||
tlsn-core = { workspace = true, features = ["mozilla-certs"] }
|
||||
tlsn-core = { workspace = true }
|
||||
tlsn-data-fixtures = { workspace = true, optional = true }
|
||||
|
||||
bcs = { workspace = true }
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
//! Attestation fixtures.
|
||||
|
||||
use tlsn_core::{
|
||||
connection::{CertBinding, CertBindingV1_2},
|
||||
fixtures::ConnectionFixture,
|
||||
@@ -12,10 +13,7 @@ use tlsn_core::{
|
||||
use crate::{
|
||||
Attestation, AttestationConfig, CryptoProvider, Extension,
|
||||
request::{Request, RequestConfig},
|
||||
signing::{
|
||||
KeyAlgId, SignatureAlgId, SignatureVerifier, SignatureVerifierProvider, Signer,
|
||||
SignerProvider,
|
||||
},
|
||||
signing::SignatureAlgId,
|
||||
};
|
||||
|
||||
/// A Request fixture used for testing.
|
||||
@@ -104,8 +102,7 @@ pub fn attestation_fixture(
|
||||
let mut provider = CryptoProvider::default();
|
||||
match signature_alg {
|
||||
SignatureAlgId::SECP256K1 => provider.signer.set_secp256k1(&[42u8; 32]).unwrap(),
|
||||
SignatureAlgId::SECP256K1ETH => provider.signer.set_secp256k1eth(&[43u8; 32]).unwrap(),
|
||||
SignatureAlgId::SECP256R1 => provider.signer.set_secp256r1(&[44u8; 32]).unwrap(),
|
||||
SignatureAlgId::SECP256R1 => provider.signer.set_secp256r1(&[42u8; 32]).unwrap(),
|
||||
_ => unimplemented!(),
|
||||
};
|
||||
|
||||
@@ -125,68 +122,3 @@ pub fn attestation_fixture(
|
||||
|
||||
attestation_builder.build(&provider).unwrap()
|
||||
}
|
||||
|
||||
/// Returns a crypto provider which supports only a custom signature alg.
|
||||
pub fn custom_provider_fixture() -> CryptoProvider {
|
||||
const CUSTOM_SIG_ALG_ID: SignatureAlgId = SignatureAlgId::new(128);
|
||||
|
||||
// A dummy signer.
|
||||
struct DummySigner {}
|
||||
impl Signer for DummySigner {
|
||||
fn alg_id(&self) -> SignatureAlgId {
|
||||
CUSTOM_SIG_ALG_ID
|
||||
}
|
||||
|
||||
fn sign(
|
||||
&self,
|
||||
msg: &[u8],
|
||||
) -> Result<crate::signing::Signature, crate::signing::SignatureError> {
|
||||
Ok(crate::signing::Signature {
|
||||
alg: CUSTOM_SIG_ALG_ID,
|
||||
data: msg.to_vec(),
|
||||
})
|
||||
}
|
||||
|
||||
fn verifying_key(&self) -> crate::signing::VerifyingKey {
|
||||
crate::signing::VerifyingKey {
|
||||
alg: KeyAlgId::new(128),
|
||||
data: vec![1, 2, 3, 4],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A dummy verifier.
|
||||
struct DummyVerifier {}
|
||||
impl SignatureVerifier for DummyVerifier {
|
||||
fn alg_id(&self) -> SignatureAlgId {
|
||||
CUSTOM_SIG_ALG_ID
|
||||
}
|
||||
|
||||
fn verify(
|
||||
&self,
|
||||
_key: &crate::signing::VerifyingKey,
|
||||
msg: &[u8],
|
||||
sig: &[u8],
|
||||
) -> Result<(), crate::signing::SignatureError> {
|
||||
if msg == sig {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(crate::signing::SignatureError::from_str(
|
||||
"invalid signature",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut provider = CryptoProvider::default();
|
||||
|
||||
let mut signer_provider = SignerProvider::default();
|
||||
signer_provider.set_signer(Box::new(DummySigner {}));
|
||||
provider.signer = signer_provider;
|
||||
|
||||
let mut verifier_provider = SignatureVerifierProvider::empty();
|
||||
verifier_provider.set_verifier(Box::new(DummyVerifier {}));
|
||||
provider.signature = verifier_provider;
|
||||
|
||||
provider
|
||||
}
|
||||
|
||||
@@ -20,10 +20,7 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use tlsn_core::hash::HashAlgId;
|
||||
|
||||
use crate::{
|
||||
Attestation, CryptoProvider, Extension, connection::ServerCertCommitment,
|
||||
serialize::CanonicalSerialize, signing::SignatureAlgId,
|
||||
};
|
||||
use crate::{Attestation, Extension, connection::ServerCertCommitment, signing::SignatureAlgId};
|
||||
|
||||
pub use builder::{RequestBuilder, RequestBuilderError};
|
||||
pub use config::{RequestConfig, RequestConfigBuilder, RequestConfigBuilderError};
|
||||
@@ -44,102 +41,44 @@ impl Request {
|
||||
}
|
||||
|
||||
/// Validates the content of the attestation against this request.
|
||||
pub fn validate(
|
||||
&self,
|
||||
attestation: &Attestation,
|
||||
provider: &CryptoProvider,
|
||||
) -> Result<(), AttestationValidationError> {
|
||||
pub fn validate(&self, attestation: &Attestation) -> Result<(), InconsistentAttestation> {
|
||||
if attestation.signature.alg != self.signature_alg {
|
||||
return Err(AttestationValidationError::inconsistent(format!(
|
||||
return Err(InconsistentAttestation(format!(
|
||||
"signature algorithm: expected {:?}, got {:?}",
|
||||
self.signature_alg, attestation.signature.alg
|
||||
)));
|
||||
}
|
||||
|
||||
if attestation.header.root.alg != self.hash_alg {
|
||||
return Err(AttestationValidationError::inconsistent(format!(
|
||||
return Err(InconsistentAttestation(format!(
|
||||
"hash algorithm: expected {:?}, got {:?}",
|
||||
self.hash_alg, attestation.header.root.alg
|
||||
)));
|
||||
}
|
||||
|
||||
if attestation.body.cert_commitment() != &self.server_cert_commitment {
|
||||
return Err(AttestationValidationError::inconsistent(
|
||||
"server certificate commitment does not match",
|
||||
return Err(InconsistentAttestation(
|
||||
"server certificate commitment does not match".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// TODO: improve the O(M*N) complexity of this check.
|
||||
for extension in &self.extensions {
|
||||
if !attestation.body.extensions().any(|e| e == extension) {
|
||||
return Err(AttestationValidationError::inconsistent(
|
||||
"extension is missing from the attestation",
|
||||
return Err(InconsistentAttestation(
|
||||
"extension is missing from the attestation".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let verifier = provider
|
||||
.signature
|
||||
.get(&attestation.signature.alg)
|
||||
.map_err(|_| {
|
||||
AttestationValidationError::provider(format!(
|
||||
"provider not configured for signature algorithm id {:?}",
|
||||
attestation.signature.alg,
|
||||
))
|
||||
})?;
|
||||
|
||||
verifier
|
||||
.verify(
|
||||
&attestation.body.verifying_key.data,
|
||||
&CanonicalSerialize::serialize(&attestation.header),
|
||||
&attestation.signature.data,
|
||||
)
|
||||
.map_err(|_| {
|
||||
AttestationValidationError::inconsistent("failed to verify the signature")
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Error for [`Request::validate`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("attestation validation error: {kind}: {message}")]
|
||||
pub struct AttestationValidationError {
|
||||
kind: ErrorKind,
|
||||
message: String,
|
||||
}
|
||||
|
||||
impl AttestationValidationError {
|
||||
fn inconsistent(msg: impl Into<String>) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::Inconsistent,
|
||||
message: msg.into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn provider(msg: impl Into<String>) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::Provider,
|
||||
message: msg.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ErrorKind {
|
||||
Inconsistent,
|
||||
Provider,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ErrorKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ErrorKind::Inconsistent => write!(f, "inconsistent"),
|
||||
ErrorKind::Provider => write!(f, "provider"),
|
||||
}
|
||||
}
|
||||
}
|
||||
#[error("inconsistent attestation: {0}")]
|
||||
pub struct InconsistentAttestation(String);
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
@@ -154,8 +93,7 @@ mod test {
|
||||
use crate::{
|
||||
CryptoProvider,
|
||||
connection::ServerCertOpening,
|
||||
fixtures::{RequestFixture, attestation_fixture, custom_provider_fixture, request_fixture},
|
||||
request::{AttestationValidationError, ErrorKind},
|
||||
fixtures::{RequestFixture, attestation_fixture, request_fixture},
|
||||
signing::SignatureAlgId,
|
||||
};
|
||||
|
||||
@@ -175,9 +113,7 @@ mod test {
|
||||
let attestation =
|
||||
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
|
||||
|
||||
let provider = CryptoProvider::default();
|
||||
|
||||
assert!(request.validate(&attestation, &provider).is_ok())
|
||||
assert!(request.validate(&attestation).is_ok())
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -198,9 +134,7 @@ mod test {
|
||||
|
||||
request.signature_alg = SignatureAlgId::SECP256R1;
|
||||
|
||||
let provider = CryptoProvider::default();
|
||||
|
||||
let res = request.validate(&attestation, &provider);
|
||||
let res = request.validate(&attestation);
|
||||
assert!(res.is_err());
|
||||
}
|
||||
|
||||
@@ -222,9 +156,7 @@ mod test {
|
||||
|
||||
request.hash_alg = HashAlgId::SHA256;
|
||||
|
||||
let provider = CryptoProvider::default();
|
||||
|
||||
let res = request.validate(&attestation, &provider);
|
||||
let res = request.validate(&attestation);
|
||||
assert!(res.is_err())
|
||||
}
|
||||
|
||||
@@ -252,62 +184,11 @@ mod test {
|
||||
});
|
||||
let opening = ServerCertOpening::new(server_cert_data);
|
||||
|
||||
let provider = CryptoProvider::default();
|
||||
let crypto_provider = CryptoProvider::default();
|
||||
request.server_cert_commitment =
|
||||
opening.commit(provider.hash.get(&HashAlgId::BLAKE3).unwrap());
|
||||
opening.commit(crypto_provider.hash.get(&HashAlgId::BLAKE3).unwrap());
|
||||
|
||||
let res = request.validate(&attestation, &provider);
|
||||
let res = request.validate(&attestation);
|
||||
assert!(res.is_err())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wrong_sig() {
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
let connection = ConnectionFixture::tlsnotary(transcript.length());
|
||||
|
||||
let RequestFixture { request, .. } = request_fixture(
|
||||
transcript,
|
||||
encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
connection.clone(),
|
||||
Blake3::default(),
|
||||
Vec::new(),
|
||||
);
|
||||
|
||||
let mut attestation =
|
||||
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
|
||||
|
||||
// Corrupt the signature.
|
||||
attestation.signature.data[1] = attestation.signature.data[1].wrapping_add(1);
|
||||
|
||||
let provider = CryptoProvider::default();
|
||||
|
||||
assert!(request.validate(&attestation, &provider).is_err())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wrong_provider() {
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
let connection = ConnectionFixture::tlsnotary(transcript.length());
|
||||
|
||||
let RequestFixture { request, .. } = request_fixture(
|
||||
transcript,
|
||||
encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
connection.clone(),
|
||||
Blake3::default(),
|
||||
Vec::new(),
|
||||
);
|
||||
|
||||
let attestation =
|
||||
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
|
||||
|
||||
let provider = custom_provider_fixture();
|
||||
|
||||
assert!(matches!(
|
||||
request.validate(&attestation, &provider),
|
||||
Err(AttestationValidationError {
|
||||
kind: ErrorKind::Provider,
|
||||
..
|
||||
})
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -202,14 +202,6 @@ impl SignatureVerifierProvider {
|
||||
.map(|s| &**s)
|
||||
.ok_or(UnknownSignatureAlgId(*alg))
|
||||
}
|
||||
|
||||
/// Returns am empty provider.
|
||||
#[cfg(any(test, feature = "fixtures"))]
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
verifiers: HashMap::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Signature verifier.
|
||||
@@ -237,14 +229,6 @@ impl_domain_separator!(VerifyingKey);
|
||||
#[error("signature verification failed: {0}")]
|
||||
pub struct SignatureError(String);
|
||||
|
||||
impl SignatureError {
|
||||
/// Creates a new error with the given message.
|
||||
#[allow(clippy::should_implement_trait)]
|
||||
pub fn from_str(msg: &str) -> Self {
|
||||
Self(msg.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// A signature.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Signature {
|
||||
|
||||
@@ -101,7 +101,7 @@ fn test_api() {
|
||||
let attestation = attestation_builder.build(&provider).unwrap();
|
||||
|
||||
// Prover validates the attestation is consistent with its request.
|
||||
request.validate(&attestation, &provider).unwrap();
|
||||
request.validate(&attestation).unwrap();
|
||||
|
||||
let mut transcript_proof_builder = secrets.transcript_proof_builder();
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "This crate provides implementations of ciphers for two parties"
|
||||
keywords = ["tls", "mpc", "2pc", "aes"]
|
||||
categories = ["cryptography"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
version = "0.1.0-alpha.13"
|
||||
edition = "2021"
|
||||
|
||||
[lints]
|
||||
@@ -15,7 +15,7 @@ workspace = true
|
||||
name = "cipher"
|
||||
|
||||
[dependencies]
|
||||
mpz-circuits = { workspace = true, features = ["aes"] }
|
||||
mpz-circuits = { workspace = true }
|
||||
mpz-vm-core = { workspace = true }
|
||||
mpz-memory-core = { workspace = true }
|
||||
|
||||
@@ -24,9 +24,11 @@ thiserror = { workspace = true }
|
||||
aes = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
mpz-common = { workspace = true, features = ["test-utils"] }
|
||||
mpz-ideal-vm = { workspace = true }
|
||||
mpz-garble = { workspace = true }
|
||||
mpz-common = { workspace = true }
|
||||
mpz-ot = { workspace = true }
|
||||
|
||||
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] }
|
||||
rand = { workspace = true }
|
||||
ctr = { workspace = true }
|
||||
cipher = { workspace = true }
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
use crate::{Cipher, CtrBlock, Keystream};
|
||||
use async_trait::async_trait;
|
||||
use mpz_circuits::{AES128_KS, AES128_POST_KS};
|
||||
use mpz_circuits::circuits::AES128;
|
||||
use mpz_memory_core::binary::{Binary, U8};
|
||||
use mpz_vm_core::{prelude::*, Call, Vm};
|
||||
use std::fmt::Debug;
|
||||
@@ -12,35 +12,13 @@ mod error;
|
||||
pub use error::AesError;
|
||||
use error::ErrorKind;
|
||||
|
||||
/// AES key schedule: 11 round keys, 16 bytes each.
|
||||
type KeySchedule = Array<U8, 176>;
|
||||
|
||||
/// Computes AES-128.
|
||||
#[derive(Default, Debug)]
|
||||
pub struct Aes128 {
|
||||
key: Option<Array<U8, 16>>,
|
||||
key_schedule: Option<KeySchedule>,
|
||||
iv: Option<Array<U8, 4>>,
|
||||
}
|
||||
|
||||
impl Aes128 {
|
||||
// Allocates key schedule.
|
||||
//
|
||||
// Expects the key to be already set.
|
||||
fn alloc_key_schedule(&self, vm: &mut dyn Vm<Binary>) -> Result<KeySchedule, AesError> {
|
||||
let ks: KeySchedule = vm
|
||||
.call(
|
||||
Call::builder(AES128_KS.clone())
|
||||
.arg(self.key.expect("key is set"))
|
||||
.build()
|
||||
.expect("call should be valid"),
|
||||
)
|
||||
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
|
||||
|
||||
Ok(ks)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Cipher for Aes128 {
|
||||
type Error = AesError;
|
||||
@@ -67,22 +45,18 @@ impl Cipher for Aes128 {
|
||||
}
|
||||
|
||||
fn alloc_block(
|
||||
&mut self,
|
||||
&self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
input: Array<U8, 16>,
|
||||
) -> Result<Self::Block, Self::Error> {
|
||||
self.key
|
||||
let key = self
|
||||
.key
|
||||
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
|
||||
|
||||
if self.key_schedule.is_none() {
|
||||
self.key_schedule = Some(self.alloc_key_schedule(vm)?);
|
||||
}
|
||||
let ks = *self.key_schedule.as_ref().expect("key schedule was set");
|
||||
|
||||
let output = vm
|
||||
.call(
|
||||
Call::builder(AES128_POST_KS.clone())
|
||||
.arg(ks)
|
||||
Call::builder(AES128.clone())
|
||||
.arg(key)
|
||||
.arg(input)
|
||||
.build()
|
||||
.expect("call should be valid"),
|
||||
@@ -93,10 +67,11 @@ impl Cipher for Aes128 {
|
||||
}
|
||||
|
||||
fn alloc_ctr_block(
|
||||
&mut self,
|
||||
&self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
) -> Result<CtrBlock<Self::Nonce, Self::Counter, Self::Block>, Self::Error> {
|
||||
self.key
|
||||
let key = self
|
||||
.key
|
||||
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
|
||||
let iv = self
|
||||
.iv
|
||||
@@ -114,15 +89,10 @@ impl Cipher for Aes128 {
|
||||
vm.mark_public(counter)
|
||||
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
|
||||
|
||||
if self.key_schedule.is_none() {
|
||||
self.key_schedule = Some(self.alloc_key_schedule(vm)?);
|
||||
}
|
||||
let ks = *self.key_schedule.as_ref().expect("key schedule was set");
|
||||
|
||||
let output = vm
|
||||
.call(
|
||||
Call::builder(AES128_POST_KS.clone())
|
||||
.arg(ks)
|
||||
Call::builder(AES128.clone())
|
||||
.arg(key)
|
||||
.arg(iv)
|
||||
.arg(explicit_nonce)
|
||||
.arg(counter)
|
||||
@@ -139,11 +109,12 @@ impl Cipher for Aes128 {
|
||||
}
|
||||
|
||||
fn alloc_keystream(
|
||||
&mut self,
|
||||
&self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
len: usize,
|
||||
) -> Result<Keystream<Self::Nonce, Self::Counter, Self::Block>, Self::Error> {
|
||||
self.key
|
||||
let key = self
|
||||
.key
|
||||
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
|
||||
let iv = self
|
||||
.iv
|
||||
@@ -172,15 +143,10 @@ impl Cipher for Aes128 {
|
||||
let blocks = inputs
|
||||
.into_iter()
|
||||
.map(|(explicit_nonce, counter)| {
|
||||
if self.key_schedule.is_none() {
|
||||
self.key_schedule = Some(self.alloc_key_schedule(vm)?);
|
||||
}
|
||||
let ks = *self.key_schedule.as_ref().expect("key schedule was set");
|
||||
|
||||
let output = vm
|
||||
.call(
|
||||
Call::builder(AES128_POST_KS.clone())
|
||||
.arg(ks)
|
||||
Call::builder(AES128.clone())
|
||||
.arg(key)
|
||||
.arg(iv)
|
||||
.arg(explicit_nonce)
|
||||
.arg(counter)
|
||||
@@ -206,12 +172,15 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::Cipher;
|
||||
use mpz_common::context::test_st_context;
|
||||
use mpz_ideal_vm::IdealVm;
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
|
||||
use mpz_memory_core::{
|
||||
binary::{Binary, U8},
|
||||
correlated::Delta,
|
||||
Array, MemoryExt, Vector, ViewExt,
|
||||
};
|
||||
use mpz_ot::ideal::cot::ideal_cot;
|
||||
use mpz_vm_core::{Execute, Vm};
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_aes_ctr() {
|
||||
@@ -221,11 +190,10 @@ mod tests {
|
||||
let start_counter = 3u32;
|
||||
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let mut gen = IdealVm::new();
|
||||
let mut ev = IdealVm::new();
|
||||
let (mut gen, mut ev) = mock_vm();
|
||||
|
||||
let mut aes_gen = setup_ctr(key, iv, &mut gen);
|
||||
let mut aes_ev = setup_ctr(key, iv, &mut ev);
|
||||
let aes_gen = setup_ctr(key, iv, &mut gen);
|
||||
let aes_ev = setup_ctr(key, iv, &mut ev);
|
||||
|
||||
let msg = vec![42u8; 128];
|
||||
|
||||
@@ -284,11 +252,10 @@ mod tests {
|
||||
let input = [5_u8; 16];
|
||||
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let mut gen = IdealVm::new();
|
||||
let mut ev = IdealVm::new();
|
||||
let (mut gen, mut ev) = mock_vm();
|
||||
|
||||
let mut aes_gen = setup_block(key, &mut gen);
|
||||
let mut aes_ev = setup_block(key, &mut ev);
|
||||
let aes_gen = setup_block(key, &mut gen);
|
||||
let aes_ev = setup_block(key, &mut ev);
|
||||
|
||||
let block_ref_gen: Array<U8, 16> = gen.alloc().unwrap();
|
||||
gen.mark_public(block_ref_gen).unwrap();
|
||||
@@ -327,6 +294,18 @@ mod tests {
|
||||
assert_eq!(ciphertext_gen, expected);
|
||||
}
|
||||
|
||||
fn mock_vm() -> (impl Vm<Binary>, impl Vm<Binary>) {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let delta = Delta::random(&mut rng);
|
||||
|
||||
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
|
||||
|
||||
let gen = Garbler::new(cot_send, [0u8; 16], delta);
|
||||
let ev = Evaluator::new(cot_recv);
|
||||
|
||||
(gen, ev)
|
||||
}
|
||||
|
||||
fn setup_ctr(key: [u8; 16], iv: [u8; 4], vm: &mut dyn Vm<Binary>) -> Aes128 {
|
||||
let key_ref: Array<U8, 16> = vm.alloc().unwrap();
|
||||
vm.mark_public(key_ref).unwrap();
|
||||
|
||||
@@ -55,7 +55,7 @@ pub trait Cipher {
|
||||
|
||||
/// Allocates a single block in ECB mode.
|
||||
fn alloc_block(
|
||||
&mut self,
|
||||
&self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
input: Self::Block,
|
||||
) -> Result<Self::Block, Self::Error>;
|
||||
@@ -63,7 +63,7 @@ pub trait Cipher {
|
||||
/// Allocates a single block in counter mode.
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn alloc_ctr_block(
|
||||
&mut self,
|
||||
&self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
) -> Result<CtrBlock<Self::Nonce, Self::Counter, Self::Block>, Self::Error>;
|
||||
|
||||
@@ -75,7 +75,7 @@ pub trait Cipher {
|
||||
/// * `len` - Length of the stream in bytes.
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn alloc_keystream(
|
||||
&mut self,
|
||||
&self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
len: usize,
|
||||
) -> Result<Keystream<Self::Nonce, Self::Counter, Self::Block>, Self::Error>;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tlsn-deap"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
version = "0.1.0-alpha.13"
|
||||
edition = "2021"
|
||||
|
||||
[lints]
|
||||
@@ -19,8 +19,11 @@ futures = { workspace = true }
|
||||
tokio = { workspace = true, features = ["sync"] }
|
||||
|
||||
[dev-dependencies]
|
||||
mpz-circuits = { workspace = true, features = ["aes"] }
|
||||
mpz-common = { workspace = true, features = ["test-utils"] }
|
||||
mpz-ideal-vm = { workspace = true }
|
||||
mpz-circuits = { workspace = true }
|
||||
mpz-garble = { workspace = true }
|
||||
mpz-ot = { workspace = true }
|
||||
mpz-zk = { workspace = true }
|
||||
|
||||
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
|
||||
rand = { workspace = true }
|
||||
rand06-compat = { workspace = true }
|
||||
|
||||
@@ -15,7 +15,7 @@ use mpz_vm_core::{
|
||||
memory::{binary::Binary, DecodeFuture, Memory, Repr, Slice, View},
|
||||
Call, Callable, Execute, Vm, VmError,
|
||||
};
|
||||
use rangeset::{ops::Set, set::RangeSet};
|
||||
use rangeset::{Difference, RangeSet, UnionMut};
|
||||
use tokio::sync::{Mutex, MutexGuard, OwnedMutexGuard};
|
||||
|
||||
type Error = DeapError;
|
||||
@@ -210,12 +210,10 @@ where
|
||||
}
|
||||
|
||||
fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
|
||||
let slice_range = slice.to_range();
|
||||
|
||||
// Follower's private inputs are not committed in the ZK VM until finalization.
|
||||
let input_minus_follower = slice_range.difference(&self.follower_input_ranges);
|
||||
let input_minus_follower = slice.to_range().difference(&self.follower_input_ranges);
|
||||
let mut zk = self.zk.try_lock().unwrap();
|
||||
for input in input_minus_follower {
|
||||
for input in input_minus_follower.iter_ranges() {
|
||||
zk.commit_raw(
|
||||
self.memory_map
|
||||
.try_get(Slice::from_range_unchecked(input))?,
|
||||
@@ -268,7 +266,7 @@ where
|
||||
mpc.mark_private_raw(slice)?;
|
||||
// Follower's private inputs will become public during finalization.
|
||||
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
|
||||
self.follower_input_ranges.union_mut(slice.to_range());
|
||||
self.follower_input_ranges.union_mut(&slice.to_range());
|
||||
self.follower_inputs.push(slice);
|
||||
}
|
||||
}
|
||||
@@ -284,7 +282,7 @@ where
|
||||
mpc.mark_blind_raw(slice)?;
|
||||
// Follower's private inputs will become public during finalization.
|
||||
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
|
||||
self.follower_input_ranges.union_mut(slice.to_range());
|
||||
self.follower_input_ranges.union_mut(&slice.to_range());
|
||||
self.follower_inputs.push(slice);
|
||||
}
|
||||
Role::Follower => {
|
||||
@@ -384,27 +382,37 @@ enum ErrorRepr {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use mpz_circuits::AES128;
|
||||
use mpz_circuits::circuits::AES128;
|
||||
use mpz_common::context::test_st_context;
|
||||
use mpz_ideal_vm::IdealVm;
|
||||
use mpz_core::Block;
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
|
||||
use mpz_ot::ideal::{cot::ideal_cot, rcot::ideal_rcot};
|
||||
use mpz_vm_core::{
|
||||
memory::{binary::U8, Array},
|
||||
memory::{binary::U8, correlated::Delta, Array},
|
||||
prelude::*,
|
||||
};
|
||||
use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_deap() {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let delta_mpc = Delta::random(&mut rng);
|
||||
let delta_zk = Delta::random(&mut rng);
|
||||
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
|
||||
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
|
||||
|
||||
let leader_mpc = IdealVm::new();
|
||||
let leader_zk = IdealVm::new();
|
||||
let follower_mpc = IdealVm::new();
|
||||
let follower_zk = IdealVm::new();
|
||||
let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc);
|
||||
let ev = Evaluator::new(cot_recv);
|
||||
let prover = Prover::new(ProverConfig::default(), rcot_recv);
|
||||
let verifier = Verifier::new(VerifierConfig::default(), delta_zk, rcot_send);
|
||||
|
||||
let mut leader = Deap::new(Role::Leader, leader_mpc, leader_zk);
|
||||
let mut follower = Deap::new(Role::Follower, follower_mpc, follower_zk);
|
||||
let mut leader = Deap::new(Role::Leader, gb, prover);
|
||||
let mut follower = Deap::new(Role::Follower, ev, verifier);
|
||||
|
||||
let (ct_leader, ct_follower) = futures::join!(
|
||||
async {
|
||||
@@ -470,15 +478,21 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_deap_desync_memory() {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let delta_mpc = Delta::random(&mut rng);
|
||||
let delta_zk = Delta::random(&mut rng);
|
||||
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
|
||||
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
|
||||
|
||||
let leader_mpc = IdealVm::new();
|
||||
let leader_zk = IdealVm::new();
|
||||
let follower_mpc = IdealVm::new();
|
||||
let follower_zk = IdealVm::new();
|
||||
let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc);
|
||||
let ev = Evaluator::new(cot_recv);
|
||||
let prover = Prover::new(ProverConfig::default(), rcot_recv);
|
||||
let verifier = Verifier::new(VerifierConfig::default(), delta_zk, rcot_send);
|
||||
|
||||
let mut leader = Deap::new(Role::Leader, leader_mpc, leader_zk);
|
||||
let mut follower = Deap::new(Role::Follower, follower_mpc, follower_zk);
|
||||
let mut leader = Deap::new(Role::Leader, gb, prover);
|
||||
let mut follower = Deap::new(Role::Follower, ev, verifier);
|
||||
|
||||
// Desynchronize the memories.
|
||||
let _ = leader.zk().alloc_raw(1).unwrap();
|
||||
@@ -550,15 +564,21 @@ mod tests {
|
||||
// detection by the follower.
|
||||
#[tokio::test]
|
||||
async fn test_malicious() {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let delta_mpc = Delta::random(&mut rng);
|
||||
let delta_zk = Delta::random(&mut rng);
|
||||
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
|
||||
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
|
||||
|
||||
let leader_mpc = IdealVm::new();
|
||||
let leader_zk = IdealVm::new();
|
||||
let follower_mpc = IdealVm::new();
|
||||
let follower_zk = IdealVm::new();
|
||||
let gb = Garbler::new(cot_send, [1u8; 16], delta_mpc);
|
||||
let ev = Evaluator::new(cot_recv);
|
||||
let prover = Prover::new(ProverConfig::default(), rcot_recv);
|
||||
let verifier = Verifier::new(VerifierConfig::default(), delta_zk, rcot_send);
|
||||
|
||||
let mut leader = Deap::new(Role::Leader, leader_mpc, leader_zk);
|
||||
let mut follower = Deap::new(Role::Follower, follower_mpc, follower_zk);
|
||||
let mut leader = Deap::new(Role::Leader, gb, prover);
|
||||
let mut follower = Deap::new(Role::Follower, ev, verifier);
|
||||
|
||||
let (_, follower_res) = futures::join!(
|
||||
async {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use mpz_vm_core::{memory::Slice, VmError};
|
||||
use rangeset::ops::Set;
|
||||
use rangeset::Subset;
|
||||
|
||||
/// A mapping between the memories of the MPC and ZK VMs.
|
||||
#[derive(Debug, Default)]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "A 2PC implementation of TLS HMAC-SHA256 PRF"
|
||||
keywords = ["tls", "mpc", "2pc", "hmac", "sha256"]
|
||||
categories = ["cryptography"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
version = "0.1.0-alpha.13"
|
||||
edition = "2021"
|
||||
|
||||
[lints]
|
||||
@@ -20,13 +20,14 @@ mpz-core = { workspace = true }
|
||||
mpz-circuits = { workspace = true }
|
||||
mpz-hash = { workspace = true }
|
||||
|
||||
sha2 = { workspace = true, features = ["compress"] }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
mpz-ot = { workspace = true, features = ["ideal"] }
|
||||
mpz-garble = { workspace = true }
|
||||
mpz-common = { workspace = true, features = ["test-utils"] }
|
||||
mpz-ideal-vm = { workspace = true }
|
||||
|
||||
criterion = { workspace = true, features = ["async_tokio"] }
|
||||
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
|
||||
|
||||
@@ -4,12 +4,14 @@ use criterion::{criterion_group, criterion_main, Criterion};
|
||||
|
||||
use hmac_sha256::{Mode, MpcPrf};
|
||||
use mpz_common::context::test_mt_context;
|
||||
use mpz_ideal_vm::IdealVm;
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
|
||||
use mpz_ot::ideal::cot::ideal_cot;
|
||||
use mpz_vm_core::{
|
||||
memory::{binary::U8, Array},
|
||||
memory::{binary::U8, correlated::Delta, Array},
|
||||
prelude::*,
|
||||
Execute,
|
||||
};
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
|
||||
#[allow(clippy::unit_arg)]
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
@@ -27,6 +29,8 @@ criterion_group!(benches, criterion_benchmark);
|
||||
criterion_main!(benches);
|
||||
|
||||
async fn prf(mode: Mode) {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
|
||||
let pms = [42u8; 32];
|
||||
let client_random = [69u8; 32];
|
||||
let server_random: [u8; 32] = [96u8; 32];
|
||||
@@ -35,8 +39,11 @@ async fn prf(mode: Mode) {
|
||||
let mut leader_ctx = leader_exec.new_context().await.unwrap();
|
||||
let mut follower_ctx = follower_exec.new_context().await.unwrap();
|
||||
|
||||
let mut leader_vm = IdealVm::new();
|
||||
let mut follower_vm = IdealVm::new();
|
||||
let delta = Delta::random(&mut rng);
|
||||
let (ot_send, ot_recv) = ideal_cot(delta.into_inner());
|
||||
|
||||
let mut leader_vm = Garbler::new(ot_send, [0u8; 16], delta);
|
||||
let mut follower_vm = Evaluator::new(ot_recv);
|
||||
|
||||
let leader_pms: Array<U8, 32> = leader_vm.alloc().unwrap();
|
||||
leader_vm.mark_public(leader_pms).unwrap();
|
||||
|
||||
@@ -54,11 +54,10 @@ mod tests {
|
||||
use crate::{
|
||||
hmac::hmac_sha256,
|
||||
sha256, state_to_bytes,
|
||||
test_utils::{compute_inner_local, compute_outer_partial},
|
||||
test_utils::{compute_inner_local, compute_outer_partial, mock_vm},
|
||||
};
|
||||
use mpz_common::context::test_st_context;
|
||||
use mpz_hash::sha256::Sha256;
|
||||
use mpz_ideal_vm::IdealVm;
|
||||
use mpz_vm_core::{
|
||||
memory::{
|
||||
binary::{U32, U8},
|
||||
@@ -84,8 +83,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_hmac_circuit() {
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let mut leader = IdealVm::new();
|
||||
let mut follower = IdealVm::new();
|
||||
let (mut leader, mut follower) = mock_vm();
|
||||
|
||||
let (inputs, references) = test_fixtures();
|
||||
for (input, &reference) in inputs.iter().zip(references.iter()) {
|
||||
|
||||
@@ -72,11 +72,10 @@ fn state_to_bytes(input: [u32; 8]) -> [u8; 32] {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
test_utils::{prf_cf_vd, prf_keys, prf_ms, prf_sf_vd},
|
||||
test_utils::{mock_vm, prf_cf_vd, prf_keys, prf_ms, prf_sf_vd},
|
||||
Mode, MpcPrf, SessionKeys,
|
||||
};
|
||||
use mpz_common::context::test_st_context;
|
||||
use mpz_ideal_vm::IdealVm;
|
||||
use mpz_vm_core::{
|
||||
memory::{binary::U8, Array, MemoryExt, ViewExt},
|
||||
Execute,
|
||||
@@ -124,8 +123,7 @@ mod tests {
|
||||
|
||||
// Set up vm and prf
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(128);
|
||||
let mut leader = IdealVm::new();
|
||||
let mut follower = IdealVm::new();
|
||||
let (mut leader, mut follower) = mock_vm();
|
||||
|
||||
let leader_pms: Array<U8, 32> = leader.alloc().unwrap();
|
||||
leader.mark_public(leader_pms).unwrap();
|
||||
|
||||
@@ -339,9 +339,8 @@ fn gen_merge_circ(size: usize) -> Arc<Circuit> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::prf::merge_outputs;
|
||||
use crate::{prf::merge_outputs, test_utils::mock_vm};
|
||||
use mpz_common::context::test_st_context;
|
||||
use mpz_ideal_vm::IdealVm;
|
||||
use mpz_vm_core::{
|
||||
memory::{binary::U8, Array, MemoryExt, ViewExt},
|
||||
Execute,
|
||||
@@ -350,8 +349,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_merge_outputs() {
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let mut leader = IdealVm::new();
|
||||
let mut follower = IdealVm::new();
|
||||
let (mut leader, mut follower) = mock_vm();
|
||||
|
||||
let input1: [u8; 32] = std::array::from_fn(|i| i as u8);
|
||||
let input2: [u8; 32] = std::array::from_fn(|i| i as u8 + 32);
|
||||
|
||||
@@ -137,11 +137,10 @@ impl Prf {
|
||||
mod tests {
|
||||
use crate::{
|
||||
prf::{compute_partial, function::Prf},
|
||||
test_utils::phash,
|
||||
test_utils::{mock_vm, phash},
|
||||
Mode,
|
||||
};
|
||||
use mpz_common::context::test_st_context;
|
||||
use mpz_ideal_vm::IdealVm;
|
||||
use mpz_vm_core::{
|
||||
memory::{binary::U8, Array, MemoryExt, ViewExt},
|
||||
Execute,
|
||||
@@ -167,8 +166,7 @@ mod tests {
|
||||
let mut rng = ThreadRng::default();
|
||||
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let mut leader = IdealVm::new();
|
||||
let mut follower = IdealVm::new();
|
||||
let (mut leader, mut follower) = mock_vm();
|
||||
|
||||
let key: [u8; 32] = rng.random();
|
||||
let start_seed: Vec<u8> = vec![42; 64];
|
||||
|
||||
@@ -1,10 +1,25 @@
|
||||
use crate::{sha256, state_to_bytes};
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
|
||||
use mpz_ot::ideal::cot::{ideal_cot, IdealCOTReceiver, IdealCOTSender};
|
||||
use mpz_vm_core::memory::correlated::Delta;
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
|
||||
pub(crate) const SHA256_IV: [u32; 8] = [
|
||||
0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
|
||||
];
|
||||
|
||||
pub(crate) fn mock_vm() -> (Garbler<IdealCOTSender>, Evaluator<IdealCOTReceiver>) {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let delta = Delta::random(&mut rng);
|
||||
|
||||
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
|
||||
|
||||
let gen = Garbler::new(cot_send, [0u8; 16], delta);
|
||||
let ev = Evaluator::new(cot_recv);
|
||||
|
||||
(gen, ev)
|
||||
}
|
||||
|
||||
pub(crate) fn prf_ms(pms: [u8; 32], client_random: [u8; 32], server_random: [u8; 32]) -> [u8; 48] {
|
||||
let mut label_start_seed = b"master secret".to_vec();
|
||||
label_start_seed.extend_from_slice(&client_random);
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Implementation of the 3-party key-exchange protocol"
|
||||
keywords = ["tls", "mpc", "2pc", "pms", "key-exchange"]
|
||||
categories = ["cryptography"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
version = "0.1.0-alpha.13"
|
||||
edition = "2021"
|
||||
|
||||
[lints]
|
||||
@@ -40,7 +40,6 @@ tokio = { workspace = true, features = ["sync"] }
|
||||
[dev-dependencies]
|
||||
mpz-ot = { workspace = true, features = ["ideal"] }
|
||||
mpz-garble = { workspace = true }
|
||||
mpz-ideal-vm = { workspace = true }
|
||||
|
||||
rand_core = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
|
||||
|
||||
@@ -459,7 +459,9 @@ mod tests {
|
||||
use mpz_common::context::test_st_context;
|
||||
use mpz_core::Block;
|
||||
use mpz_fields::UniformRand;
|
||||
use mpz_ideal_vm::IdealVm;
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
|
||||
use mpz_memory_core::correlated::Delta;
|
||||
use mpz_ot::ideal::cot::{ideal_cot, IdealCOTReceiver, IdealCOTSender};
|
||||
use mpz_share_conversion::ideal::{
|
||||
ideal_share_convert, IdealShareConvertReceiver, IdealShareConvertSender,
|
||||
};
|
||||
@@ -482,8 +484,7 @@ mod tests {
|
||||
async fn test_key_exchange() {
|
||||
let mut rng = StdRng::seed_from_u64(0).compat();
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let mut gen = IdealVm::new();
|
||||
let mut ev = IdealVm::new();
|
||||
let (mut gen, mut ev) = mock_vm();
|
||||
|
||||
let leader_private_key = SecretKey::random(&mut rng);
|
||||
let follower_private_key = SecretKey::random(&mut rng);
|
||||
@@ -624,8 +625,7 @@ mod tests {
|
||||
async fn test_malicious_key_exchange(#[case] malicious: Malicious) {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let mut gen = IdealVm::new();
|
||||
let mut ev = IdealVm::new();
|
||||
let (mut gen, mut ev) = mock_vm();
|
||||
|
||||
let leader_private_key = SecretKey::random(&mut rng.compat_by_ref());
|
||||
let follower_private_key = SecretKey::random(&mut rng.compat_by_ref());
|
||||
@@ -704,8 +704,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_circuit() {
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let gen = IdealVm::new();
|
||||
let ev = IdealVm::new();
|
||||
let (gen, ev) = mock_vm();
|
||||
|
||||
let share_a0_bytes = [5_u8; 32];
|
||||
let share_a1_bytes = [2_u8; 32];
|
||||
@@ -835,4 +834,16 @@ mod tests {
|
||||
|
||||
(leader, follower)
|
||||
}
|
||||
|
||||
fn mock_vm() -> (Garbler<IdealCOTSender>, Evaluator<IdealCOTReceiver>) {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let delta = Delta::random(&mut rng);
|
||||
|
||||
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
|
||||
|
||||
let gen = Garbler::new(cot_send, [0u8; 16], delta);
|
||||
let ev = Evaluator::new(cot_recv);
|
||||
|
||||
(gen, ev)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
//! with the server alone and forward all messages from and to the follower.
|
||||
//!
|
||||
//! A detailed description of this protocol can be found in our documentation
|
||||
//! <https://tlsnotary.org/docs/mpc/key_exchange>.
|
||||
//! <https://docs.tlsnotary.org/protocol/notarization/key_exchange.html>.
|
||||
|
||||
#![deny(missing_docs, unreachable_pub, unused_must_use)]
|
||||
#![deny(clippy::all)]
|
||||
|
||||
@@ -26,7 +26,8 @@ pub fn create_mock_key_exchange_pair() -> (MockKeyExchange, MockKeyExchange) {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use mpz_ideal_vm::IdealVm;
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
|
||||
use mpz_ot::ideal::cot::{IdealCOTReceiver, IdealCOTSender};
|
||||
|
||||
use super::*;
|
||||
use crate::KeyExchange;
|
||||
@@ -39,12 +40,12 @@ mod tests {
|
||||
|
||||
is_key_exchange::<
|
||||
MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>,
|
||||
IdealVm,
|
||||
Garbler<IdealCOTSender>,
|
||||
>(leader);
|
||||
|
||||
is_key_exchange::<
|
||||
MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>,
|
||||
IdealVm,
|
||||
Evaluator<IdealCOTReceiver>,
|
||||
>(follower);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
//! protocol has semi-honest security.
|
||||
//!
|
||||
//! The protocol is described in
|
||||
//! <https://tlsnotary.org/docs/mpc/key_exchange>
|
||||
//! <https://docs.tlsnotary.org/protocol/notarization/key_exchange.html>
|
||||
|
||||
use crate::{KeyExchangeError, Role};
|
||||
use mpz_common::{Context, Flush};
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Core types for TLSNotary"
|
||||
keywords = ["tls", "mpc", "2pc", "types"]
|
||||
categories = ["cryptography"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
version = "0.1.0-alpha.13"
|
||||
edition = "2021"
|
||||
|
||||
[lints]
|
||||
@@ -13,7 +13,6 @@ workspace = true
|
||||
|
||||
[features]
|
||||
default = []
|
||||
mozilla-certs = ["dep:webpki-root-certs", "dep:webpki-roots"]
|
||||
fixtures = [
|
||||
"dep:hex",
|
||||
"dep:tlsn-data-fixtures",
|
||||
@@ -45,8 +44,7 @@ sha2 = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tiny-keccak = { workspace = true, features = ["keccak"] }
|
||||
web-time = { workspace = true }
|
||||
webpki-roots = { workspace = true, optional = true }
|
||||
webpki-root-certs = { workspace = true, optional = true }
|
||||
webpki-roots = { workspace = true }
|
||||
rustls-webpki = { workspace = true, features = ["ring"] }
|
||||
rustls-pki-types = { workspace = true }
|
||||
itybity = { workspace = true }
|
||||
@@ -59,7 +57,5 @@ generic-array = { workspace = true }
|
||||
bincode = { workspace = true }
|
||||
hex = { workspace = true }
|
||||
rstest = { workspace = true }
|
||||
tlsn-core = { workspace = true, features = ["fixtures"] }
|
||||
tlsn-attestation = { workspace = true, features = ["fixtures"] }
|
||||
tlsn-data-fixtures = { workspace = true }
|
||||
webpki-root-certs = { workspace = true }
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
//! Configuration types.
|
||||
|
||||
pub mod prove;
|
||||
pub mod prover;
|
||||
pub mod tls;
|
||||
pub mod tls_commit;
|
||||
pub mod verifier;
|
||||
@@ -1,189 +0,0 @@
|
||||
//! Proving configuration.
|
||||
|
||||
use rangeset::set::{RangeSet, ToRangeSet};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::transcript::{Direction, Transcript, TranscriptCommitConfig, TranscriptCommitRequest};
|
||||
|
||||
/// Configuration to prove information to the verifier.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProveConfig {
|
||||
server_identity: bool,
|
||||
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
|
||||
transcript_commit: Option<TranscriptCommitConfig>,
|
||||
}
|
||||
|
||||
impl ProveConfig {
|
||||
/// Creates a new builder.
|
||||
pub fn builder(transcript: &Transcript) -> ProveConfigBuilder<'_> {
|
||||
ProveConfigBuilder::new(transcript)
|
||||
}
|
||||
|
||||
/// Returns `true` if the server identity is to be proven.
|
||||
pub fn server_identity(&self) -> bool {
|
||||
self.server_identity
|
||||
}
|
||||
|
||||
/// Returns the sent and received ranges of the transcript to be revealed,
|
||||
/// respectively.
|
||||
pub fn reveal(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
|
||||
self.reveal.as_ref()
|
||||
}
|
||||
|
||||
/// Returns the transcript commitment configuration.
|
||||
pub fn transcript_commit(&self) -> Option<&TranscriptCommitConfig> {
|
||||
self.transcript_commit.as_ref()
|
||||
}
|
||||
|
||||
/// Returns a request.
|
||||
pub fn to_request(&self) -> ProveRequest {
|
||||
ProveRequest {
|
||||
server_identity: self.server_identity,
|
||||
reveal: self.reveal.clone(),
|
||||
transcript_commit: self
|
||||
.transcript_commit
|
||||
.clone()
|
||||
.map(|config| config.to_request()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for [`ProveConfig`].
|
||||
#[derive(Debug)]
|
||||
pub struct ProveConfigBuilder<'a> {
|
||||
transcript: &'a Transcript,
|
||||
server_identity: bool,
|
||||
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
|
||||
transcript_commit: Option<TranscriptCommitConfig>,
|
||||
}
|
||||
|
||||
impl<'a> ProveConfigBuilder<'a> {
|
||||
/// Creates a new builder.
|
||||
pub fn new(transcript: &'a Transcript) -> Self {
|
||||
Self {
|
||||
transcript,
|
||||
server_identity: false,
|
||||
reveal: None,
|
||||
transcript_commit: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Proves the server identity.
|
||||
pub fn server_identity(&mut self) -> &mut Self {
|
||||
self.server_identity = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Configures transcript commitments.
|
||||
pub fn transcript_commit(&mut self, transcript_commit: TranscriptCommitConfig) -> &mut Self {
|
||||
self.transcript_commit = Some(transcript_commit);
|
||||
self
|
||||
}
|
||||
|
||||
/// Reveals the given ranges of the transcript.
|
||||
pub fn reveal(
|
||||
&mut self,
|
||||
direction: Direction,
|
||||
ranges: &dyn ToRangeSet<usize>,
|
||||
) -> Result<&mut Self, ProveConfigError> {
|
||||
let idx = ranges.to_range_set();
|
||||
|
||||
if idx.end().unwrap_or(0) > self.transcript.len_of_direction(direction) {
|
||||
return Err(ProveConfigError(ErrorRepr::IndexOutOfBounds {
|
||||
direction,
|
||||
actual: idx.end().unwrap_or(0),
|
||||
len: self.transcript.len_of_direction(direction),
|
||||
}));
|
||||
}
|
||||
|
||||
let (sent, recv) = self.reveal.get_or_insert_default();
|
||||
match direction {
|
||||
Direction::Sent => sent.union_mut(&idx),
|
||||
Direction::Received => recv.union_mut(&idx),
|
||||
}
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Reveals the given ranges of the sent data transcript.
|
||||
pub fn reveal_sent(
|
||||
&mut self,
|
||||
ranges: &dyn ToRangeSet<usize>,
|
||||
) -> Result<&mut Self, ProveConfigError> {
|
||||
self.reveal(Direction::Sent, ranges)
|
||||
}
|
||||
|
||||
/// Reveals all of the sent data transcript.
|
||||
pub fn reveal_sent_all(&mut self) -> Result<&mut Self, ProveConfigError> {
|
||||
let len = self.transcript.len_of_direction(Direction::Sent);
|
||||
let (sent, _) = self.reveal.get_or_insert_default();
|
||||
sent.union_mut(&(0..len));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Reveals the given ranges of the received data transcript.
|
||||
pub fn reveal_recv(
|
||||
&mut self,
|
||||
ranges: &dyn ToRangeSet<usize>,
|
||||
) -> Result<&mut Self, ProveConfigError> {
|
||||
self.reveal(Direction::Received, ranges)
|
||||
}
|
||||
|
||||
/// Reveals all of the received data transcript.
|
||||
pub fn reveal_recv_all(&mut self) -> Result<&mut Self, ProveConfigError> {
|
||||
let len = self.transcript.len_of_direction(Direction::Received);
|
||||
let (_, recv) = self.reveal.get_or_insert_default();
|
||||
recv.union_mut(&(0..len));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Builds the configuration.
|
||||
pub fn build(self) -> Result<ProveConfig, ProveConfigError> {
|
||||
Ok(ProveConfig {
|
||||
server_identity: self.server_identity,
|
||||
reveal: self.reveal,
|
||||
transcript_commit: self.transcript_commit,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Request to prove statements about the connection.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProveRequest {
|
||||
server_identity: bool,
|
||||
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
|
||||
transcript_commit: Option<TranscriptCommitRequest>,
|
||||
}
|
||||
|
||||
impl ProveRequest {
|
||||
/// Returns `true` if the server identity is to be proven.
|
||||
pub fn server_identity(&self) -> bool {
|
||||
self.server_identity
|
||||
}
|
||||
|
||||
/// Returns the sent and received ranges of the transcript to be revealed,
|
||||
/// respectively.
|
||||
pub fn reveal(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
|
||||
self.reveal.as_ref()
|
||||
}
|
||||
|
||||
/// Returns the transcript commitment configuration.
|
||||
pub fn transcript_commit(&self) -> Option<&TranscriptCommitRequest> {
|
||||
self.transcript_commit.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
/// Error for [`ProveConfig`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct ProveConfigError(#[from] ErrorRepr);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum ErrorRepr {
|
||||
#[error("range is out of bounds of the transcript ({direction}): {actual} > {len}")]
|
||||
IndexOutOfBounds {
|
||||
direction: Direction,
|
||||
actual: usize,
|
||||
len: usize,
|
||||
},
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
//! Prover configuration.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Prover configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProverConfig {}
|
||||
|
||||
impl ProverConfig {
|
||||
/// Creates a new builder.
|
||||
pub fn builder() -> ProverConfigBuilder {
|
||||
ProverConfigBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for [`ProverConfig`].
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ProverConfigBuilder {}
|
||||
|
||||
impl ProverConfigBuilder {
|
||||
/// Builds the configuration.
|
||||
pub fn build(self) -> Result<ProverConfig, ProverConfigError> {
|
||||
Ok(ProverConfig {})
|
||||
}
|
||||
}
|
||||
|
||||
/// Error for [`ProverConfig`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct ProverConfigError(#[from] ErrorRepr);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum ErrorRepr {}
|
||||
@@ -1,111 +0,0 @@
|
||||
//! TLS client configuration.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
connection::ServerName,
|
||||
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
|
||||
};
|
||||
|
||||
/// TLS client configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TlsClientConfig {
|
||||
server_name: ServerName,
|
||||
/// Root certificates.
|
||||
root_store: RootCertStore,
|
||||
/// Certificate chain and a matching private key for client
|
||||
/// authentication.
|
||||
client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
|
||||
}
|
||||
|
||||
impl TlsClientConfig {
|
||||
/// Creates a new builder.
|
||||
pub fn builder() -> TlsConfigBuilder {
|
||||
TlsConfigBuilder::default()
|
||||
}
|
||||
|
||||
/// Returns the server name.
|
||||
pub fn server_name(&self) -> &ServerName {
|
||||
&self.server_name
|
||||
}
|
||||
|
||||
/// Returns the root certificates.
|
||||
pub fn root_store(&self) -> &RootCertStore {
|
||||
&self.root_store
|
||||
}
|
||||
|
||||
/// Returns a certificate chain and a matching private key for client
|
||||
/// authentication.
|
||||
pub fn client_auth(&self) -> Option<&(Vec<CertificateDer>, PrivateKeyDer)> {
|
||||
self.client_auth.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for [`TlsClientConfig`].
|
||||
#[derive(Debug, Default)]
|
||||
pub struct TlsConfigBuilder {
|
||||
server_name: Option<ServerName>,
|
||||
root_store: Option<RootCertStore>,
|
||||
client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
|
||||
}
|
||||
|
||||
impl TlsConfigBuilder {
|
||||
/// Sets the server name.
|
||||
pub fn server_name(mut self, server_name: ServerName) -> Self {
|
||||
self.server_name = Some(server_name);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the root certificates to use for verifying the server's
|
||||
/// certificate.
|
||||
pub fn root_store(mut self, store: RootCertStore) -> Self {
|
||||
self.root_store = Some(store);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets a DER-encoded certificate chain and a matching private key for
|
||||
/// client authentication.
|
||||
///
|
||||
/// Often the chain will consist of a single end-entity certificate.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cert_key` - A tuple containing the certificate chain and the private
|
||||
/// key.
|
||||
///
|
||||
/// - Each certificate in the chain must be in the X.509 format.
|
||||
/// - The key must be in the ASN.1 format (either PKCS#8 or PKCS#1).
|
||||
pub fn client_auth(mut self, cert_key: (Vec<CertificateDer>, PrivateKeyDer)) -> Self {
|
||||
self.client_auth = Some(cert_key);
|
||||
self
|
||||
}
|
||||
|
||||
/// Builds the TLS configuration.
|
||||
pub fn build(self) -> Result<TlsClientConfig, TlsConfigError> {
|
||||
let server_name = self.server_name.ok_or(ErrorRepr::MissingField {
|
||||
field: "server_name",
|
||||
})?;
|
||||
|
||||
let root_store = self.root_store.ok_or(ErrorRepr::MissingField {
|
||||
field: "root_store",
|
||||
})?;
|
||||
|
||||
Ok(TlsClientConfig {
|
||||
server_name,
|
||||
root_store,
|
||||
client_auth: self.client_auth,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// TLS configuration error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct TlsConfigError(#[from] ErrorRepr);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("tls config error")]
|
||||
enum ErrorRepr {
|
||||
#[error("missing required field: {field}")]
|
||||
MissingField { field: &'static str },
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
//! TLS commitment configuration.
|
||||
|
||||
pub mod mpc;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// TLS commitment configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TlsCommitConfig {
|
||||
protocol: TlsCommitProtocolConfig,
|
||||
}
|
||||
|
||||
impl TlsCommitConfig {
|
||||
/// Creates a new builder.
|
||||
pub fn builder() -> TlsCommitConfigBuilder {
|
||||
TlsCommitConfigBuilder::default()
|
||||
}
|
||||
|
||||
/// Returns the protocol configuration.
|
||||
pub fn protocol(&self) -> &TlsCommitProtocolConfig {
|
||||
&self.protocol
|
||||
}
|
||||
|
||||
/// Returns a TLS commitment request.
|
||||
pub fn to_request(&self) -> TlsCommitRequest {
|
||||
TlsCommitRequest {
|
||||
config: self.protocol.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for [`TlsCommitConfig`].
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
|
||||
pub struct TlsCommitConfigBuilder {
|
||||
protocol: Option<TlsCommitProtocolConfig>,
|
||||
}
|
||||
|
||||
impl TlsCommitConfigBuilder {
|
||||
/// Sets the protocol configuration.
|
||||
pub fn protocol<C>(mut self, protocol: C) -> Self
|
||||
where
|
||||
C: Into<TlsCommitProtocolConfig>,
|
||||
{
|
||||
self.protocol = Some(protocol.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Builds the configuration.
|
||||
pub fn build(self) -> Result<TlsCommitConfig, TlsCommitConfigError> {
|
||||
let protocol = self
|
||||
.protocol
|
||||
.ok_or(ErrorRepr::MissingField { name: "protocol" })?;
|
||||
|
||||
Ok(TlsCommitConfig { protocol })
|
||||
}
|
||||
}
|
||||
|
||||
/// TLS commitment protocol configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[non_exhaustive]
|
||||
pub enum TlsCommitProtocolConfig {
|
||||
/// MPC-TLS configuration.
|
||||
Mpc(mpc::MpcTlsConfig),
|
||||
}
|
||||
|
||||
impl From<mpc::MpcTlsConfig> for TlsCommitProtocolConfig {
|
||||
fn from(config: mpc::MpcTlsConfig) -> Self {
|
||||
Self::Mpc(config)
|
||||
}
|
||||
}
|
||||
|
||||
/// TLS commitment request.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TlsCommitRequest {
|
||||
config: TlsCommitProtocolConfig,
|
||||
}
|
||||
|
||||
impl TlsCommitRequest {
|
||||
/// Returns the protocol configuration.
|
||||
pub fn protocol(&self) -> &TlsCommitProtocolConfig {
|
||||
&self.config
|
||||
}
|
||||
}
|
||||
|
||||
/// Error for [`TlsCommitConfig`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct TlsCommitConfigError(#[from] ErrorRepr);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum ErrorRepr {
|
||||
#[error("missing field: {name}")]
|
||||
MissingField { name: &'static str },
|
||||
}
|
||||
@@ -1,241 +0,0 @@
|
||||
//! MPC-TLS commitment protocol configuration.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// Default is 32 bytes to decrypt the TLS protocol messages.
|
||||
const DEFAULT_MAX_RECV_ONLINE: usize = 32;
|
||||
|
||||
/// MPC-TLS commitment protocol configuration.
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
#[serde(try_from = "unchecked::MpcTlsConfigUnchecked")]
|
||||
pub struct MpcTlsConfig {
|
||||
/// Maximum number of bytes that can be sent.
|
||||
max_sent_data: usize,
|
||||
/// Maximum number of application data records that can be sent.
|
||||
max_sent_records: Option<usize>,
|
||||
/// Maximum number of bytes that can be decrypted online, i.e. while the
|
||||
/// MPC-TLS connection is active.
|
||||
max_recv_data_online: usize,
|
||||
/// Maximum number of bytes that can be received.
|
||||
max_recv_data: usize,
|
||||
/// Maximum number of received application data records that can be
|
||||
/// decrypted online, i.e. while the MPC-TLS connection is active.
|
||||
max_recv_records_online: Option<usize>,
|
||||
/// Whether the `deferred decryption` feature is toggled on from the start
|
||||
/// of the MPC-TLS connection.
|
||||
defer_decryption_from_start: bool,
|
||||
/// Network settings.
|
||||
network: NetworkSetting,
|
||||
}
|
||||
|
||||
impl MpcTlsConfig {
|
||||
/// Creates a new builder.
|
||||
pub fn builder() -> MpcTlsConfigBuilder {
|
||||
MpcTlsConfigBuilder::default()
|
||||
}
|
||||
|
||||
/// Returns the maximum number of bytes that can be sent.
|
||||
pub fn max_sent_data(&self) -> usize {
|
||||
self.max_sent_data
|
||||
}
|
||||
|
||||
/// Returns the maximum number of application data records that can
|
||||
/// be sent.
|
||||
pub fn max_sent_records(&self) -> Option<usize> {
|
||||
self.max_sent_records
|
||||
}
|
||||
|
||||
/// Returns the maximum number of bytes that can be decrypted online.
|
||||
pub fn max_recv_data_online(&self) -> usize {
|
||||
self.max_recv_data_online
|
||||
}
|
||||
|
||||
/// Returns the maximum number of bytes that can be received.
|
||||
pub fn max_recv_data(&self) -> usize {
|
||||
self.max_recv_data
|
||||
}
|
||||
|
||||
/// Returns the maximum number of received application data records that
|
||||
/// can be decrypted online.
|
||||
pub fn max_recv_records_online(&self) -> Option<usize> {
|
||||
self.max_recv_records_online
|
||||
}
|
||||
|
||||
/// Returns whether the `deferred decryption` feature is toggled on from the
|
||||
/// start of the MPC-TLS connection.
|
||||
pub fn defer_decryption_from_start(&self) -> bool {
|
||||
self.defer_decryption_from_start
|
||||
}
|
||||
|
||||
/// Returns the network settings.
|
||||
pub fn network(&self) -> NetworkSetting {
|
||||
self.network
|
||||
}
|
||||
}
|
||||
|
||||
fn validate(config: MpcTlsConfig) -> Result<MpcTlsConfig, MpcTlsConfigError> {
|
||||
if config.max_recv_data_online > config.max_recv_data {
|
||||
return Err(ErrorRepr::InvalidValue {
|
||||
name: "max_recv_data_online",
|
||||
reason: format!(
|
||||
"must be <= max_recv_data ({} > {})",
|
||||
config.max_recv_data_online, config.max_recv_data
|
||||
),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Builder for [`MpcTlsConfig`].
|
||||
#[derive(Debug, Default)]
|
||||
pub struct MpcTlsConfigBuilder {
|
||||
max_sent_data: Option<usize>,
|
||||
max_sent_records: Option<usize>,
|
||||
max_recv_data_online: Option<usize>,
|
||||
max_recv_data: Option<usize>,
|
||||
max_recv_records_online: Option<usize>,
|
||||
defer_decryption_from_start: Option<bool>,
|
||||
network: Option<NetworkSetting>,
|
||||
}
|
||||
|
||||
impl MpcTlsConfigBuilder {
|
||||
/// Sets the maximum number of bytes that can be sent.
|
||||
pub fn max_sent_data(mut self, max_sent_data: usize) -> Self {
|
||||
self.max_sent_data = Some(max_sent_data);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the maximum number of application data records that can be sent.
|
||||
pub fn max_sent_records(mut self, max_sent_records: usize) -> Self {
|
||||
self.max_sent_records = Some(max_sent_records);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the maximum number of bytes that can be decrypted online.
|
||||
pub fn max_recv_data_online(mut self, max_recv_data_online: usize) -> Self {
|
||||
self.max_recv_data_online = Some(max_recv_data_online);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the maximum number of bytes that can be received.
|
||||
pub fn max_recv_data(mut self, max_recv_data: usize) -> Self {
|
||||
self.max_recv_data = Some(max_recv_data);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the maximum number of received application data records that can
|
||||
/// be decrypted online.
|
||||
pub fn max_recv_records_online(mut self, max_recv_records_online: usize) -> Self {
|
||||
self.max_recv_records_online = Some(max_recv_records_online);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets whether the `deferred decryption` feature is toggled on from the
|
||||
/// start of the MPC-TLS connection.
|
||||
pub fn defer_decryption_from_start(mut self, defer_decryption_from_start: bool) -> Self {
|
||||
self.defer_decryption_from_start = Some(defer_decryption_from_start);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the network settings.
|
||||
pub fn network(mut self, network: NetworkSetting) -> Self {
|
||||
self.network = Some(network);
|
||||
self
|
||||
}
|
||||
|
||||
/// Builds the configuration.
|
||||
pub fn build(self) -> Result<MpcTlsConfig, MpcTlsConfigError> {
|
||||
let Self {
|
||||
max_sent_data,
|
||||
max_sent_records,
|
||||
max_recv_data_online,
|
||||
max_recv_data,
|
||||
max_recv_records_online,
|
||||
defer_decryption_from_start,
|
||||
network,
|
||||
} = self;
|
||||
|
||||
let max_sent_data = max_sent_data.ok_or(ErrorRepr::MissingField {
|
||||
name: "max_sent_data",
|
||||
})?;
|
||||
|
||||
let max_recv_data_online = max_recv_data_online.unwrap_or(DEFAULT_MAX_RECV_ONLINE);
|
||||
let max_recv_data = max_recv_data.ok_or(ErrorRepr::MissingField {
|
||||
name: "max_recv_data",
|
||||
})?;
|
||||
|
||||
let defer_decryption_from_start = defer_decryption_from_start.unwrap_or(true);
|
||||
let network = network.unwrap_or_default();
|
||||
|
||||
validate(MpcTlsConfig {
|
||||
max_sent_data,
|
||||
max_sent_records,
|
||||
max_recv_data_online,
|
||||
max_recv_data,
|
||||
max_recv_records_online,
|
||||
defer_decryption_from_start,
|
||||
network,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Settings for the network environment.
|
||||
///
|
||||
/// Provides optimization options to adapt the protocol to different network
|
||||
/// situations.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
|
||||
pub enum NetworkSetting {
|
||||
/// Reduces network round-trips at the expense of consuming more network
|
||||
/// bandwidth.
|
||||
Bandwidth,
|
||||
/// Reduces network bandwidth utilization at the expense of more network
|
||||
/// round-trips.
|
||||
#[default]
|
||||
Latency,
|
||||
}
|
||||
|
||||
/// Error for [`MpcTlsConfig`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct MpcTlsConfigError(#[from] ErrorRepr);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum ErrorRepr {
|
||||
#[error("missing field: {name}")]
|
||||
MissingField { name: &'static str },
|
||||
#[error("invalid value for field({name}): {reason}")]
|
||||
InvalidValue { name: &'static str, reason: String },
|
||||
}
|
||||
|
||||
mod unchecked {
|
||||
use super::*;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct MpcTlsConfigUnchecked {
|
||||
max_sent_data: usize,
|
||||
max_sent_records: Option<usize>,
|
||||
max_recv_data_online: usize,
|
||||
max_recv_data: usize,
|
||||
max_recv_records_online: Option<usize>,
|
||||
defer_decryption_from_start: bool,
|
||||
network: NetworkSetting,
|
||||
}
|
||||
|
||||
impl TryFrom<MpcTlsConfigUnchecked> for MpcTlsConfig {
|
||||
type Error = MpcTlsConfigError;
|
||||
|
||||
fn try_from(value: MpcTlsConfigUnchecked) -> Result<Self, Self::Error> {
|
||||
validate(MpcTlsConfig {
|
||||
max_sent_data: value.max_sent_data,
|
||||
max_sent_records: value.max_sent_records,
|
||||
max_recv_data_online: value.max_recv_data_online,
|
||||
max_recv_data: value.max_recv_data,
|
||||
max_recv_records_online: value.max_recv_records_online,
|
||||
defer_decryption_from_start: value.defer_decryption_from_start,
|
||||
network: value.network,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
//! Verifier configuration.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::webpki::RootCertStore;
|
||||
|
||||
/// Verifier configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VerifierConfig {
|
||||
root_store: RootCertStore,
|
||||
}
|
||||
|
||||
impl VerifierConfig {
|
||||
/// Creates a new builder.
|
||||
pub fn builder() -> VerifierConfigBuilder {
|
||||
VerifierConfigBuilder::default()
|
||||
}
|
||||
|
||||
/// Returns the root certificate store.
|
||||
pub fn root_store(&self) -> &RootCertStore {
|
||||
&self.root_store
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for [`VerifierConfig`].
|
||||
#[derive(Debug, Default)]
|
||||
pub struct VerifierConfigBuilder {
|
||||
root_store: Option<RootCertStore>,
|
||||
}
|
||||
|
||||
impl VerifierConfigBuilder {
|
||||
/// Sets the root certificate store.
|
||||
pub fn root_store(mut self, root_store: RootCertStore) -> Self {
|
||||
self.root_store = Some(root_store);
|
||||
self
|
||||
}
|
||||
|
||||
/// Builds the configuration.
|
||||
pub fn build(self) -> Result<VerifierConfig, VerifierConfigError> {
|
||||
let root_store = self
|
||||
.root_store
|
||||
.ok_or(ErrorRepr::MissingField { name: "root_store" })?;
|
||||
Ok(VerifierConfig { root_store })
|
||||
}
|
||||
}
|
||||
|
||||
/// Error for [`VerifierConfig`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct VerifierConfigError(#[from] ErrorRepr);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum ErrorRepr {
|
||||
#[error("missing field: {name}")]
|
||||
MissingField { name: &'static str },
|
||||
}
|
||||
@@ -1,11 +1,11 @@
|
||||
use rangeset::set::RangeSet;
|
||||
use rangeset::RangeSet;
|
||||
|
||||
pub(crate) struct FmtRangeSet<'a>(pub &'a RangeSet<usize>);
|
||||
|
||||
impl<'a> std::fmt::Display for FmtRangeSet<'a> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str("{")?;
|
||||
for range in self.0.iter() {
|
||||
for range in self.0.iter_ranges() {
|
||||
write!(f, "{}..{}", range.start, range.end)?;
|
||||
if range.end < self.0.end().unwrap_or(0) {
|
||||
f.write_str(", ")?;
|
||||
|
||||
@@ -2,13 +2,12 @@
|
||||
|
||||
use aead::Payload as AeadPayload;
|
||||
use aes_gcm::{aead::Aead, Aes128Gcm, NewAead};
|
||||
#[allow(deprecated)]
|
||||
use generic_array::GenericArray;
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
use tls_core::msgs::{
|
||||
base::Payload,
|
||||
codec::Codec,
|
||||
enums::{HandshakeType, ProtocolVersion},
|
||||
enums::{ContentType, HandshakeType, ProtocolVersion},
|
||||
handshake::{HandshakeMessagePayload, HandshakePayload},
|
||||
message::{OpaqueMessage, PlainMessage},
|
||||
};
|
||||
@@ -16,7 +15,7 @@ use tls_core::msgs::{
|
||||
use crate::{
|
||||
connection::{TranscriptLength, VerifyData},
|
||||
fixtures::ConnectionFixture,
|
||||
transcript::{ContentType, Record, TlsTranscript},
|
||||
transcript::{Record, TlsTranscript},
|
||||
};
|
||||
|
||||
/// The key used for encryption of the sent and received transcript.
|
||||
@@ -104,7 +103,7 @@ impl TranscriptGenerator {
|
||||
|
||||
let explicit_nonce: [u8; 8] = seq.to_be_bytes();
|
||||
let msg = PlainMessage {
|
||||
typ: ContentType::ApplicationData.into(),
|
||||
typ: ContentType::ApplicationData,
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: Payload::new(plaintext),
|
||||
};
|
||||
@@ -139,7 +138,7 @@ impl TranscriptGenerator {
|
||||
handshake_message.encode(&mut plaintext);
|
||||
|
||||
let msg = PlainMessage {
|
||||
typ: ContentType::Handshake.into(),
|
||||
typ: ContentType::Handshake,
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: Payload::new(plaintext.clone()),
|
||||
};
|
||||
@@ -181,7 +180,6 @@ fn aes_gcm_encrypt(
|
||||
let mut nonce = [0u8; 12];
|
||||
nonce[..4].copy_from_slice(&iv);
|
||||
nonce[4..].copy_from_slice(&explicit_nonce);
|
||||
#[allow(deprecated)]
|
||||
let nonce = GenericArray::from_slice(&nonce);
|
||||
let cipher = Aes128Gcm::new_from_slice(&key).unwrap();
|
||||
|
||||
|
||||
@@ -296,14 +296,14 @@ mod sha2 {
|
||||
fn hash(&self, data: &[u8]) -> super::Hash {
|
||||
let mut hasher = ::sha2::Sha256::default();
|
||||
hasher.update(data);
|
||||
super::Hash::new(hasher.finalize().as_ref())
|
||||
super::Hash::new(hasher.finalize().as_slice())
|
||||
}
|
||||
|
||||
fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
|
||||
let mut hasher = ::sha2::Sha256::default();
|
||||
hasher.update(prefix);
|
||||
hasher.update(data);
|
||||
super::Hash::new(hasher.finalize().as_ref())
|
||||
super::Hash::new(hasher.finalize().as_slice())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,18 +12,212 @@ pub mod merkle;
|
||||
pub mod transcript;
|
||||
pub mod webpki;
|
||||
pub use rangeset;
|
||||
pub mod config;
|
||||
pub(crate) mod display;
|
||||
|
||||
use rangeset::{RangeSet, ToRangeSet, UnionMut};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
connection::ServerName,
|
||||
connection::{HandshakeData, ServerName},
|
||||
transcript::{
|
||||
encoding::EncoderSecret, PartialTranscript, TranscriptCommitment, TranscriptSecret,
|
||||
encoding::EncoderSecret, Direction, PartialTranscript, Transcript, TranscriptCommitConfig,
|
||||
TranscriptCommitRequest, TranscriptCommitment, TranscriptSecret,
|
||||
},
|
||||
};
|
||||
|
||||
/// Configuration to prove information to the verifier.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProveConfig {
|
||||
server_identity: bool,
|
||||
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
|
||||
transcript_commit: Option<TranscriptCommitConfig>,
|
||||
}
|
||||
|
||||
impl ProveConfig {
|
||||
/// Creates a new builder.
|
||||
pub fn builder(transcript: &Transcript) -> ProveConfigBuilder<'_> {
|
||||
ProveConfigBuilder::new(transcript)
|
||||
}
|
||||
|
||||
/// Returns `true` if the server identity is to be proven.
|
||||
pub fn server_identity(&self) -> bool {
|
||||
self.server_identity
|
||||
}
|
||||
|
||||
/// Returns the ranges of the transcript to be revealed.
|
||||
pub fn reveal(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
|
||||
self.reveal.as_ref()
|
||||
}
|
||||
|
||||
/// Returns the transcript commitment configuration.
|
||||
pub fn transcript_commit(&self) -> Option<&TranscriptCommitConfig> {
|
||||
self.transcript_commit.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for [`ProveConfig`].
|
||||
#[derive(Debug)]
|
||||
pub struct ProveConfigBuilder<'a> {
|
||||
transcript: &'a Transcript,
|
||||
server_identity: bool,
|
||||
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
|
||||
transcript_commit: Option<TranscriptCommitConfig>,
|
||||
}
|
||||
|
||||
impl<'a> ProveConfigBuilder<'a> {
|
||||
/// Creates a new builder.
|
||||
pub fn new(transcript: &'a Transcript) -> Self {
|
||||
Self {
|
||||
transcript,
|
||||
server_identity: false,
|
||||
reveal: None,
|
||||
transcript_commit: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Proves the server identity.
|
||||
pub fn server_identity(&mut self) -> &mut Self {
|
||||
self.server_identity = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Configures transcript commitments.
|
||||
pub fn transcript_commit(&mut self, transcript_commit: TranscriptCommitConfig) -> &mut Self {
|
||||
self.transcript_commit = Some(transcript_commit);
|
||||
self
|
||||
}
|
||||
|
||||
/// Reveals the given ranges of the transcript.
|
||||
pub fn reveal(
|
||||
&mut self,
|
||||
direction: Direction,
|
||||
ranges: &dyn ToRangeSet<usize>,
|
||||
) -> Result<&mut Self, ProveConfigBuilderError> {
|
||||
let idx = ranges.to_range_set();
|
||||
|
||||
if idx.end().unwrap_or(0) > self.transcript.len_of_direction(direction) {
|
||||
return Err(ProveConfigBuilderError(
|
||||
ProveConfigBuilderErrorRepr::IndexOutOfBounds {
|
||||
direction,
|
||||
actual: idx.end().unwrap_or(0),
|
||||
len: self.transcript.len_of_direction(direction),
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
let (sent, recv) = self.reveal.get_or_insert_default();
|
||||
match direction {
|
||||
Direction::Sent => sent.union_mut(&idx),
|
||||
Direction::Received => recv.union_mut(&idx),
|
||||
}
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Reveals the given ranges of the sent data transcript.
|
||||
pub fn reveal_sent(
|
||||
&mut self,
|
||||
ranges: &dyn ToRangeSet<usize>,
|
||||
) -> Result<&mut Self, ProveConfigBuilderError> {
|
||||
self.reveal(Direction::Sent, ranges)
|
||||
}
|
||||
|
||||
/// Reveals all of the sent data transcript.
|
||||
pub fn reveal_sent_all(&mut self) -> Result<&mut Self, ProveConfigBuilderError> {
|
||||
let len = self.transcript.len_of_direction(Direction::Sent);
|
||||
let (sent, _) = self.reveal.get_or_insert_default();
|
||||
sent.union_mut(&(0..len));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Reveals the given ranges of the received data transcript.
|
||||
pub fn reveal_recv(
|
||||
&mut self,
|
||||
ranges: &dyn ToRangeSet<usize>,
|
||||
) -> Result<&mut Self, ProveConfigBuilderError> {
|
||||
self.reveal(Direction::Received, ranges)
|
||||
}
|
||||
|
||||
/// Reveals all of the received data transcript.
|
||||
pub fn reveal_recv_all(&mut self) -> Result<&mut Self, ProveConfigBuilderError> {
|
||||
let len = self.transcript.len_of_direction(Direction::Received);
|
||||
let (_, recv) = self.reveal.get_or_insert_default();
|
||||
recv.union_mut(&(0..len));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Builds the configuration.
|
||||
pub fn build(self) -> Result<ProveConfig, ProveConfigBuilderError> {
|
||||
Ok(ProveConfig {
|
||||
server_identity: self.server_identity,
|
||||
reveal: self.reveal,
|
||||
transcript_commit: self.transcript_commit,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Error for [`ProveConfigBuilder`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct ProveConfigBuilderError(#[from] ProveConfigBuilderErrorRepr);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum ProveConfigBuilderErrorRepr {
|
||||
#[error("range is out of bounds of the transcript ({direction}): {actual} > {len}")]
|
||||
IndexOutOfBounds {
|
||||
direction: Direction,
|
||||
actual: usize,
|
||||
len: usize,
|
||||
},
|
||||
}
|
||||
|
||||
/// Configuration to verify information from the prover.
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
|
||||
pub struct VerifyConfig {}
|
||||
|
||||
impl VerifyConfig {
|
||||
/// Creates a new builder.
|
||||
pub fn builder() -> VerifyConfigBuilder {
|
||||
VerifyConfigBuilder::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for [`VerifyConfig`].
|
||||
#[derive(Debug, Default)]
|
||||
pub struct VerifyConfigBuilder {}
|
||||
|
||||
impl VerifyConfigBuilder {
|
||||
/// Creates a new builder.
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
/// Builds the configuration.
|
||||
pub fn build(self) -> Result<VerifyConfig, VerifyConfigBuilderError> {
|
||||
Ok(VerifyConfig {})
|
||||
}
|
||||
}
|
||||
|
||||
/// Error for [`VerifyConfigBuilder`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct VerifyConfigBuilderError(#[from] VerifyConfigBuilderErrorRepr);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum VerifyConfigBuilderErrorRepr {}
|
||||
|
||||
/// Request to prove statements about the connection.
|
||||
#[doc(hidden)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ProveRequest {
|
||||
/// Handshake data.
|
||||
pub handshake: Option<(ServerName, HandshakeData)>,
|
||||
/// Transcript data.
|
||||
pub transcript: Option<PartialTranscript>,
|
||||
/// Transcript commitment configuration.
|
||||
pub transcript_commit: Option<TranscriptCommitRequest>,
|
||||
}
|
||||
|
||||
/// Prover output.
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ProverOutput {
|
||||
|
||||
@@ -26,11 +26,7 @@ mod tls;
|
||||
|
||||
use std::{fmt, ops::Range};
|
||||
|
||||
use rangeset::{
|
||||
iter::RangeIterator,
|
||||
ops::{Index, Set},
|
||||
set::RangeSet,
|
||||
};
|
||||
use rangeset::{Difference, IndexRanges, RangeSet, Union};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::connection::TranscriptLength;
|
||||
@@ -42,7 +38,8 @@ pub use commit::{
|
||||
pub use proof::{
|
||||
TranscriptProof, TranscriptProofBuilder, TranscriptProofBuilderError, TranscriptProofError,
|
||||
};
|
||||
pub use tls::{ContentType, Record, TlsTranscript};
|
||||
pub use tls::{Record, TlsTranscript};
|
||||
pub use tls_core::msgs::enums::ContentType;
|
||||
|
||||
/// A transcript contains the plaintext of all application data communicated
|
||||
/// between the Prover and the Server.
|
||||
@@ -110,14 +107,8 @@ impl Transcript {
|
||||
}
|
||||
|
||||
Some(
|
||||
Subsequence::new(
|
||||
idx.clone(),
|
||||
data.index(idx).fold(Vec::new(), |mut acc, s| {
|
||||
acc.extend_from_slice(s);
|
||||
acc
|
||||
}),
|
||||
)
|
||||
.expect("data is same length as index"),
|
||||
Subsequence::new(idx.clone(), data.index_ranges(idx))
|
||||
.expect("data is same length as index"),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -139,11 +130,11 @@ impl Transcript {
|
||||
let mut sent = vec![0; self.sent.len()];
|
||||
let mut received = vec![0; self.received.len()];
|
||||
|
||||
for range in sent_idx.iter() {
|
||||
for range in sent_idx.iter_ranges() {
|
||||
sent[range.clone()].copy_from_slice(&self.sent[range]);
|
||||
}
|
||||
|
||||
for range in recv_idx.iter() {
|
||||
for range in recv_idx.iter_ranges() {
|
||||
received[range.clone()].copy_from_slice(&self.received[range]);
|
||||
}
|
||||
|
||||
@@ -196,20 +187,12 @@ pub struct CompressedPartialTranscript {
|
||||
impl From<PartialTranscript> for CompressedPartialTranscript {
|
||||
fn from(uncompressed: PartialTranscript) -> Self {
|
||||
Self {
|
||||
sent_authed: uncompressed.sent.index(&uncompressed.sent_authed_idx).fold(
|
||||
Vec::new(),
|
||||
|mut acc, s| {
|
||||
acc.extend_from_slice(s);
|
||||
acc
|
||||
},
|
||||
),
|
||||
sent_authed: uncompressed
|
||||
.sent
|
||||
.index_ranges(&uncompressed.sent_authed_idx),
|
||||
received_authed: uncompressed
|
||||
.received
|
||||
.index(&uncompressed.received_authed_idx)
|
||||
.fold(Vec::new(), |mut acc, s| {
|
||||
acc.extend_from_slice(s);
|
||||
acc
|
||||
}),
|
||||
.index_ranges(&uncompressed.received_authed_idx),
|
||||
sent_idx: uncompressed.sent_authed_idx,
|
||||
recv_idx: uncompressed.received_authed_idx,
|
||||
sent_total: uncompressed.sent.len(),
|
||||
@@ -225,7 +208,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
|
||||
|
||||
let mut offset = 0;
|
||||
|
||||
for range in compressed.sent_idx.iter() {
|
||||
for range in compressed.sent_idx.iter_ranges() {
|
||||
sent[range.clone()]
|
||||
.copy_from_slice(&compressed.sent_authed[offset..offset + range.len()]);
|
||||
offset += range.len();
|
||||
@@ -233,7 +216,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
|
||||
|
||||
let mut offset = 0;
|
||||
|
||||
for range in compressed.recv_idx.iter() {
|
||||
for range in compressed.recv_idx.iter_ranges() {
|
||||
received[range.clone()]
|
||||
.copy_from_slice(&compressed.received_authed[offset..offset + range.len()]);
|
||||
offset += range.len();
|
||||
@@ -322,16 +305,12 @@ impl PartialTranscript {
|
||||
|
||||
/// Returns the index of sent data which haven't been authenticated.
|
||||
pub fn sent_unauthed(&self) -> RangeSet<usize> {
|
||||
(0..self.sent.len())
|
||||
.difference(&self.sent_authed_idx)
|
||||
.into_set()
|
||||
(0..self.sent.len()).difference(&self.sent_authed_idx)
|
||||
}
|
||||
|
||||
/// Returns the index of received data which haven't been authenticated.
|
||||
pub fn received_unauthed(&self) -> RangeSet<usize> {
|
||||
(0..self.received.len())
|
||||
.difference(&self.received_authed_idx)
|
||||
.into_set()
|
||||
(0..self.received.len()).difference(&self.received_authed_idx)
|
||||
}
|
||||
|
||||
/// Returns an iterator over the authenticated data in the transcript.
|
||||
@@ -341,7 +320,7 @@ impl PartialTranscript {
|
||||
Direction::Received => (&self.received, &self.received_authed_idx),
|
||||
};
|
||||
|
||||
authed.iter_values().map(move |i| data[i])
|
||||
authed.iter().map(|i| data[i])
|
||||
}
|
||||
|
||||
/// Unions the authenticated data of this transcript with another.
|
||||
@@ -361,20 +340,24 @@ impl PartialTranscript {
|
||||
"received data are not the same length"
|
||||
);
|
||||
|
||||
for range in other.sent_authed_idx.difference(&self.sent_authed_idx) {
|
||||
for range in other
|
||||
.sent_authed_idx
|
||||
.difference(&self.sent_authed_idx)
|
||||
.iter_ranges()
|
||||
{
|
||||
self.sent[range.clone()].copy_from_slice(&other.sent[range]);
|
||||
}
|
||||
|
||||
for range in other
|
||||
.received_authed_idx
|
||||
.difference(&self.received_authed_idx)
|
||||
.iter_ranges()
|
||||
{
|
||||
self.received[range.clone()].copy_from_slice(&other.received[range]);
|
||||
}
|
||||
|
||||
self.sent_authed_idx.union_mut(&other.sent_authed_idx);
|
||||
self.received_authed_idx
|
||||
.union_mut(&other.received_authed_idx);
|
||||
self.sent_authed_idx = self.sent_authed_idx.union(&other.sent_authed_idx);
|
||||
self.received_authed_idx = self.received_authed_idx.union(&other.received_authed_idx);
|
||||
}
|
||||
|
||||
/// Unions an authenticated subsequence into this transcript.
|
||||
@@ -386,11 +369,11 @@ impl PartialTranscript {
|
||||
match direction {
|
||||
Direction::Sent => {
|
||||
seq.copy_to(&mut self.sent);
|
||||
self.sent_authed_idx.union_mut(&seq.idx);
|
||||
self.sent_authed_idx = self.sent_authed_idx.union(&seq.idx);
|
||||
}
|
||||
Direction::Received => {
|
||||
seq.copy_to(&mut self.received);
|
||||
self.received_authed_idx.union_mut(&seq.idx);
|
||||
self.received_authed_idx = self.received_authed_idx.union(&seq.idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -401,10 +384,10 @@ impl PartialTranscript {
|
||||
///
|
||||
/// * `value` - The value to set the unauthenticated bytes to
|
||||
pub fn set_unauthed(&mut self, value: u8) {
|
||||
for range in self.sent_unauthed().iter() {
|
||||
for range in self.sent_unauthed().iter_ranges() {
|
||||
self.sent[range].fill(value);
|
||||
}
|
||||
for range in self.received_unauthed().iter() {
|
||||
for range in self.received_unauthed().iter_ranges() {
|
||||
self.received[range].fill(value);
|
||||
}
|
||||
}
|
||||
@@ -419,13 +402,13 @@ impl PartialTranscript {
|
||||
pub fn set_unauthed_range(&mut self, value: u8, direction: Direction, range: Range<usize>) {
|
||||
match direction {
|
||||
Direction::Sent => {
|
||||
for r in range.difference(&self.sent_authed_idx) {
|
||||
self.sent[r].fill(value);
|
||||
for range in range.difference(&self.sent_authed_idx).iter_ranges() {
|
||||
self.sent[range].fill(value);
|
||||
}
|
||||
}
|
||||
Direction::Received => {
|
||||
for r in range.difference(&self.received_authed_idx) {
|
||||
self.received[r].fill(value);
|
||||
for range in range.difference(&self.received_authed_idx).iter_ranges() {
|
||||
self.received[range].fill(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -503,7 +486,7 @@ impl Subsequence {
|
||||
/// Panics if the subsequence ranges are out of bounds.
|
||||
pub(crate) fn copy_to(&self, dest: &mut [u8]) {
|
||||
let mut offset = 0;
|
||||
for range in self.idx.iter() {
|
||||
for range in self.idx.iter_ranges() {
|
||||
dest[range.clone()].copy_from_slice(&self.data[offset..offset + range.len()]);
|
||||
offset += range.len();
|
||||
}
|
||||
@@ -628,7 +611,12 @@ mod validation {
|
||||
mut partial_transcript: CompressedPartialTranscriptUnchecked,
|
||||
) {
|
||||
// Change the total to be less than the last range's end bound.
|
||||
let end = partial_transcript.sent_idx.iter().next_back().unwrap().end;
|
||||
let end = partial_transcript
|
||||
.sent_idx
|
||||
.iter_ranges()
|
||||
.next_back()
|
||||
.unwrap()
|
||||
.end;
|
||||
|
||||
partial_transcript.sent_total = end - 1;
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
use std::{collections::HashSet, fmt};
|
||||
|
||||
use rangeset::set::ToRangeSet;
|
||||
use rangeset::{ToRangeSet, UnionMut};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::{collections::HashMap, fmt};
|
||||
|
||||
use rangeset::set::RangeSet;
|
||||
use rangeset::{RangeSet, UnionMut};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
@@ -103,7 +103,7 @@ impl EncodingProof {
|
||||
}
|
||||
|
||||
expected_leaf.clear();
|
||||
for range in idx.iter() {
|
||||
for range in idx.iter_ranges() {
|
||||
encoder.encode_data(*direction, range.clone(), &data[range], &mut expected_leaf);
|
||||
}
|
||||
expected_leaf.extend_from_slice(blinder.as_bytes());
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use bimap::BiMap;
|
||||
use rangeset::set::RangeSet;
|
||||
use rangeset::{RangeSet, UnionMut};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
@@ -99,7 +99,7 @@ impl EncodingTree {
|
||||
let blinder: Blinder = rand::random();
|
||||
|
||||
encoding.clear();
|
||||
for range in idx.iter() {
|
||||
for range in idx.iter_ranges() {
|
||||
provider
|
||||
.provide_encoding(direction, range, &mut encoding)
|
||||
.map_err(|_| EncodingTreeError::MissingEncoding { index: idx.clone() })?;
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
//! Transcript proofs.
|
||||
|
||||
use rangeset::{
|
||||
iter::RangeIterator,
|
||||
ops::{Cover, Set},
|
||||
set::ToRangeSet,
|
||||
};
|
||||
use rangeset::{Cover, Difference, Subset, ToRangeSet, UnionMut};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::HashSet, fmt};
|
||||
|
||||
@@ -29,9 +25,6 @@ const DEFAULT_COMMITMENT_KINDS: &[TranscriptCommitmentKind] = &[
|
||||
TranscriptCommitmentKind::Hash {
|
||||
alg: HashAlgId::BLAKE3,
|
||||
},
|
||||
TranscriptCommitmentKind::Hash {
|
||||
alg: HashAlgId::KECCAK256,
|
||||
},
|
||||
TranscriptCommitmentKind::Encoding,
|
||||
];
|
||||
|
||||
@@ -148,7 +141,7 @@ impl TranscriptProof {
|
||||
}
|
||||
|
||||
buffer.clear();
|
||||
for range in idx.iter() {
|
||||
for range in idx.iter_ranges() {
|
||||
buffer.extend_from_slice(&plaintext[range]);
|
||||
}
|
||||
|
||||
@@ -370,7 +363,7 @@ impl<'a> TranscriptProofBuilder<'a> {
|
||||
if idx.is_subset(committed) {
|
||||
self.query_idx.union(&direction, &idx);
|
||||
} else {
|
||||
let missing = idx.difference(committed).into_set();
|
||||
let missing = idx.difference(committed);
|
||||
return Err(TranscriptProofBuilderError::new(
|
||||
BuilderErrorKind::MissingCommitment,
|
||||
format!(
|
||||
@@ -586,7 +579,7 @@ impl fmt::Display for TranscriptProofBuilderError {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rangeset::prelude::*;
|
||||
use rangeset::RangeSet;
|
||||
use rstest::rstest;
|
||||
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
|
||||
|
||||
@@ -663,7 +656,6 @@ mod tests {
|
||||
#[rstest]
|
||||
#[case::sha256(HashAlgId::SHA256)]
|
||||
#[case::blake3(HashAlgId::BLAKE3)]
|
||||
#[case::keccak256(HashAlgId::KECCAK256)]
|
||||
fn test_reveal_with_hash_commitment(#[case] alg: HashAlgId) {
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
|
||||
let provider = HashProvider::default();
|
||||
@@ -712,7 +704,6 @@ mod tests {
|
||||
#[rstest]
|
||||
#[case::sha256(HashAlgId::SHA256)]
|
||||
#[case::blake3(HashAlgId::BLAKE3)]
|
||||
#[case::keccak256(HashAlgId::KECCAK256)]
|
||||
fn test_reveal_with_inconsistent_hash_commitment(#[case] alg: HashAlgId) {
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
|
||||
let provider = HashProvider::default();
|
||||
|
||||
@@ -10,53 +10,10 @@ use crate::{
|
||||
use tls_core::msgs::{
|
||||
alert::AlertMessagePayload,
|
||||
codec::{Codec, Reader},
|
||||
enums::{AlertDescription, ProtocolVersion},
|
||||
enums::{AlertDescription, ContentType, ProtocolVersion},
|
||||
handshake::{HandshakeMessagePayload, HandshakePayload},
|
||||
};
|
||||
|
||||
/// TLS record content type.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
|
||||
pub enum ContentType {
|
||||
/// Change cipher spec protocol.
|
||||
ChangeCipherSpec,
|
||||
/// Alert protocol.
|
||||
Alert,
|
||||
/// Handshake protocol.
|
||||
Handshake,
|
||||
/// Application data protocol.
|
||||
ApplicationData,
|
||||
/// Heartbeat protocol.
|
||||
Heartbeat,
|
||||
/// Unknown protocol.
|
||||
Unknown(u8),
|
||||
}
|
||||
|
||||
impl From<ContentType> for tls_core::msgs::enums::ContentType {
|
||||
fn from(content_type: ContentType) -> Self {
|
||||
match content_type {
|
||||
ContentType::ChangeCipherSpec => tls_core::msgs::enums::ContentType::ChangeCipherSpec,
|
||||
ContentType::Alert => tls_core::msgs::enums::ContentType::Alert,
|
||||
ContentType::Handshake => tls_core::msgs::enums::ContentType::Handshake,
|
||||
ContentType::ApplicationData => tls_core::msgs::enums::ContentType::ApplicationData,
|
||||
ContentType::Heartbeat => tls_core::msgs::enums::ContentType::Heartbeat,
|
||||
ContentType::Unknown(id) => tls_core::msgs::enums::ContentType::Unknown(id),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tls_core::msgs::enums::ContentType> for ContentType {
|
||||
fn from(content_type: tls_core::msgs::enums::ContentType) -> Self {
|
||||
match content_type {
|
||||
tls_core::msgs::enums::ContentType::ChangeCipherSpec => ContentType::ChangeCipherSpec,
|
||||
tls_core::msgs::enums::ContentType::Alert => ContentType::Alert,
|
||||
tls_core::msgs::enums::ContentType::Handshake => ContentType::Handshake,
|
||||
tls_core::msgs::enums::ContentType::ApplicationData => ContentType::ApplicationData,
|
||||
tls_core::msgs::enums::ContentType::Heartbeat => ContentType::Heartbeat,
|
||||
tls_core::msgs::enums::ContentType::Unknown(id) => ContentType::Unknown(id),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A transcript of TLS records sent and received by the prover.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TlsTranscript {
|
||||
|
||||
@@ -53,21 +53,6 @@ impl RootCertStore {
|
||||
pub fn empty() -> Self {
|
||||
Self { roots: Vec::new() }
|
||||
}
|
||||
|
||||
/// Creates a root certificate store with Mozilla root certificates.
|
||||
///
|
||||
/// These certificates are sourced from [`webpki-root-certs`](https://docs.rs/webpki-root-certs/latest/webpki_root_certs/). It is not recommended to use these unless the
|
||||
/// application binary can be recompiled and deployed on-demand in the case
|
||||
/// that the root certificates need to be updated.
|
||||
#[cfg(feature = "mozilla-certs")]
|
||||
pub fn mozilla() -> Self {
|
||||
Self {
|
||||
roots: webpki_root_certs::TLS_SERVER_ROOT_CERTS
|
||||
.iter()
|
||||
.map(|cert| CertificateDer(cert.to_vec()))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Server certificate verifier.
|
||||
@@ -97,12 +82,8 @@ impl ServerCertVerifier {
|
||||
Ok(Self { roots })
|
||||
}
|
||||
|
||||
/// Creates a server certificate verifier with Mozilla root certificates.
|
||||
///
|
||||
/// These certificates are sourced from [`webpki-root-certs`](https://docs.rs/webpki-root-certs/latest/webpki_root_certs/). It is not recommended to use these unless the
|
||||
/// application binary can be recompiled and deployed on-demand in the case
|
||||
/// that the root certificates need to be updated.
|
||||
#[cfg(feature = "mozilla-certs")]
|
||||
/// Creates a new server certificate verifier with Mozilla root
|
||||
/// certificates.
|
||||
pub fn mozilla() -> Self {
|
||||
Self {
|
||||
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
|
||||
|
||||
@@ -15,7 +15,6 @@ tlsn-server-fixture = { workspace = true }
|
||||
tlsn-server-fixture-certs = { workspace = true }
|
||||
spansy = { workspace = true }
|
||||
|
||||
anyhow = { workspace = true }
|
||||
bincode = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
clap = { version = "4.5", features = ["derive"] }
|
||||
@@ -38,9 +37,7 @@ tokio = { workspace = true, features = [
|
||||
tokio-util = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
noir = { git = "https://github.com/zkmopro/noir-rs", tag = "v1.0.0-beta.8", features = [
|
||||
"barretenberg",
|
||||
] }
|
||||
noir = { git = "https://github.com/zkmopro/noir-rs", tag = "v1.0.0-beta.8", features = ["barretenberg"] }
|
||||
|
||||
[[example]]
|
||||
name = "interactive"
|
||||
|
||||
@@ -4,7 +4,5 @@ This folder contains examples demonstrating how to use the TLSNotary protocol.
|
||||
|
||||
* [Interactive](./interactive/README.md): Interactive Prover and Verifier session without a trusted notary.
|
||||
* [Attestation](./attestation/README.md): Performing a simple notarization with a trusted notary.
|
||||
* [Interactive_zk](./interactive_zk/README.md): Interactive Prover and Verifier session demonstrating zero-knowledge age verification using Noir.
|
||||
|
||||
|
||||
Refer to <https://tlsnotary.org/docs/quick_start> for a quick start guide to using TLSNotary with these examples.
|
||||
@@ -4,7 +4,6 @@
|
||||
|
||||
use std::env;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use clap::Parser;
|
||||
use http_body_util::Empty;
|
||||
use hyper::{body::Bytes, Request, StatusCode};
|
||||
@@ -24,17 +23,12 @@ use tlsn::{
|
||||
Attestation, AttestationConfig, CryptoProvider, Secrets,
|
||||
},
|
||||
config::{
|
||||
prove::ProveConfig,
|
||||
prover::ProverConfig,
|
||||
tls::TlsClientConfig,
|
||||
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig},
|
||||
verifier::VerifierConfig,
|
||||
CertificateDer, PrivateKeyDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore,
|
||||
},
|
||||
connection::{ConnectionInfo, HandshakeData, ServerName, TranscriptLength},
|
||||
prover::{state::Committed, Prover, ProverOutput},
|
||||
prover::{state::Committed, ProveConfig, Prover, ProverConfig, ProverOutput, TlsConfig},
|
||||
transcript::{ContentType, TranscriptCommitConfig},
|
||||
verifier::{Verifier, VerifierOutput},
|
||||
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
|
||||
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
|
||||
};
|
||||
use tlsn_examples::ExampleType;
|
||||
use tlsn_formats::http::{DefaultHttpCommitter, HttpCommit, HttpTranscript};
|
||||
@@ -53,7 +47,7 @@ struct Args {
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let args = Args::parse();
|
||||
@@ -87,70 +81,70 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
|
||||
async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
verifier_socket: S,
|
||||
socket: S,
|
||||
req_tx: Sender<AttestationRequest>,
|
||||
resp_rx: Receiver<Attestation>,
|
||||
uri: &str,
|
||||
extra_headers: Vec<(&str, &str)>,
|
||||
example_type: &ExampleType,
|
||||
) -> Result<()> {
|
||||
let mut verifier_socket = verifier_socket.compat();
|
||||
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into());
|
||||
let server_port: u16 = env::var("SERVER_PORT")
|
||||
.map(|port| port.parse().expect("port should be valid integer"))
|
||||
.unwrap_or(DEFAULT_FIXTURE_PORT);
|
||||
|
||||
// Create a new prover and perform necessary setup.
|
||||
let prover = Prover::new(ProverConfig::builder().build()?)
|
||||
.commit(
|
||||
TlsCommitConfig::builder()
|
||||
// Select the TLS commitment protocol.
|
||||
.protocol(
|
||||
MpcTlsConfig::builder()
|
||||
// We must configure the amount of data we expect to exchange beforehand,
|
||||
// which will be preprocessed prior to the
|
||||
// connection. Reducing these limits will improve
|
||||
// performance.
|
||||
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
|
||||
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
|
||||
.build()?,
|
||||
)
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
// server-fixture.
|
||||
let mut tls_config_builder = TlsConfig::builder();
|
||||
tls_config_builder
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
// (Optional) Set up TLS client authentication if required by the server.
|
||||
.client_auth((
|
||||
vec![CertificateDer(CLIENT_CERT_DER.to_vec())],
|
||||
PrivateKeyDer(CLIENT_KEY_DER.to_vec()),
|
||||
));
|
||||
|
||||
let tls_config = tls_config_builder.build().unwrap();
|
||||
|
||||
// Set up protocol configuration for prover.
|
||||
let mut prover_config_builder = ProverConfig::builder();
|
||||
prover_config_builder
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
|
||||
.tls_config(tls_config)
|
||||
.protocol_config(
|
||||
ProtocolConfig::builder()
|
||||
// We must configure the amount of data we expect to exchange beforehand, which will
|
||||
// be preprocessed prior to the connection. Reducing these limits will improve
|
||||
// performance.
|
||||
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
|
||||
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
|
||||
.build()?,
|
||||
&mut verifier_socket,
|
||||
)
|
||||
.await?;
|
||||
);
|
||||
|
||||
let prover_config = prover_config_builder.build()?;
|
||||
|
||||
// Create a new prover and perform necessary setup.
|
||||
let prover = Prover::new(prover_config).setup(socket.compat()).await?;
|
||||
|
||||
// Open a TCP connection to the server.
|
||||
let client_socket = tokio::net::TcpStream::connect((server_host, server_port))
|
||||
.await?
|
||||
.compat();
|
||||
let client_socket = tokio::net::TcpStream::connect((server_host, server_port)).await?;
|
||||
|
||||
// Bind the prover to the server connection.
|
||||
let (tls_connection, prover) = prover.setup(
|
||||
TlsClientConfig::builder()
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
// server-fixture.
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
// (Optional) Set up TLS client authentication if required by the server.
|
||||
.client_auth((
|
||||
vec![CertificateDer(CLIENT_CERT_DER.to_vec())],
|
||||
PrivateKeyDer(CLIENT_KEY_DER.to_vec()),
|
||||
))
|
||||
.build()?,
|
||||
)?;
|
||||
let tls_connection = TokioIo::new(tls_connection.compat());
|
||||
// The returned `mpc_tls_connection` is an MPC TLS connection to the server: all
|
||||
// data written to/read from it will be encrypted/decrypted using MPC with
|
||||
// the notary.
|
||||
let (mpc_tls_connection, prover_fut) = prover.connect(client_socket.compat()).await?;
|
||||
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat());
|
||||
|
||||
// Spawn the prover task to be run concurrently in the background.
|
||||
let prover_task = tokio::spawn(prover.run(client_socket, verifier_socket));
|
||||
let prover_task = tokio::spawn(prover_fut);
|
||||
|
||||
// Attach the hyper HTTP client to the connection.
|
||||
let (mut request_sender, connection) =
|
||||
hyper::client::conn::http1::handshake(tls_connection).await?;
|
||||
hyper::client::conn::http1::handshake(mpc_tls_connection).await?;
|
||||
|
||||
// Spawn the HTTP task to be run concurrently in the background.
|
||||
tokio::spawn(connection);
|
||||
@@ -171,7 +165,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
}
|
||||
let request = request_builder.body(Empty::<Bytes>::new())?;
|
||||
|
||||
info!("Starting connection with the server");
|
||||
info!("Starting an MPC TLS connection with the server");
|
||||
|
||||
// Send the request to the server and wait for the response.
|
||||
let response = request_sender.send_request(request).await?;
|
||||
@@ -181,7 +175,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
assert!(response.status() == StatusCode::OK);
|
||||
|
||||
// The prover task should be done now, so we can await it.
|
||||
let (prover, _, verifier_socket) = prover_task.await??;
|
||||
let prover = prover_task.await??;
|
||||
|
||||
// Parse the HTTP transcript.
|
||||
let transcript = HttpTranscript::parse(prover.transcript())?;
|
||||
@@ -223,8 +217,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
|
||||
let request_config = builder.build()?;
|
||||
|
||||
let (attestation, secrets) =
|
||||
notarize(prover, &request_config, verifier_socket, req_tx, resp_rx).await?;
|
||||
let (attestation, secrets) = notarize(prover, &request_config, req_tx, resp_rx).await?;
|
||||
|
||||
// Write the attestation to disk.
|
||||
let attestation_path = tlsn_examples::get_file_path(example_type, "attestation");
|
||||
@@ -244,13 +237,12 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn notarize<S: futures::AsyncRead + futures::AsyncWrite + Send + Unpin>(
|
||||
async fn notarize(
|
||||
mut prover: Prover<Committed>,
|
||||
config: &RequestConfig,
|
||||
mut verifier_socket: S,
|
||||
request_tx: Sender<AttestationRequest>,
|
||||
attestation_rx: Receiver<Attestation>,
|
||||
) -> Result<(Attestation, Secrets)> {
|
||||
) -> Result<(Attestation, Secrets), Box<dyn std::error::Error>> {
|
||||
let mut builder = ProveConfig::builder(prover.transcript());
|
||||
|
||||
if let Some(config) = config.transcript_commit() {
|
||||
@@ -263,13 +255,11 @@ async fn notarize<S: futures::AsyncRead + futures::AsyncWrite + Send + Unpin>(
|
||||
transcript_commitments,
|
||||
transcript_secrets,
|
||||
..
|
||||
} = prover
|
||||
.prove(&disclosure_config, &mut verifier_socket)
|
||||
.await?;
|
||||
} = prover.prove(&disclosure_config).await?;
|
||||
|
||||
let transcript = prover.transcript().clone();
|
||||
let tls_transcript = prover.tls_transcript().clone();
|
||||
prover.close(&mut verifier_socket).await?;
|
||||
prover.close().await?;
|
||||
|
||||
// Build an attestation request.
|
||||
let mut builder = AttestationRequest::builder(config);
|
||||
@@ -295,28 +285,30 @@ async fn notarize<S: futures::AsyncRead + futures::AsyncWrite + Send + Unpin>(
|
||||
// Send attestation request to notary.
|
||||
request_tx
|
||||
.send(request.clone())
|
||||
.map_err(|_| anyhow!("notary is not receiving attestation request"))?;
|
||||
.map_err(|_| "notary is not receiving attestation request".to_string())?;
|
||||
|
||||
// Receive attestation from notary.
|
||||
let attestation = attestation_rx
|
||||
.await
|
||||
.map_err(|err| anyhow!("notary did not respond with attestation: {err}"))?;
|
||||
|
||||
// Signature verifier for the signature algorithm in the request.
|
||||
let provider = CryptoProvider::default();
|
||||
.map_err(|err| format!("notary did not respond with attestation: {err}"))?;
|
||||
|
||||
// Check the attestation is consistent with the Prover's view.
|
||||
request.validate(&attestation, &provider)?;
|
||||
request.validate(&attestation)?;
|
||||
|
||||
Ok((attestation, secrets))
|
||||
}
|
||||
|
||||
async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
prover_socket: S,
|
||||
socket: S,
|
||||
request_rx: Receiver<AttestationRequest>,
|
||||
attestation_tx: Sender<Attestation>,
|
||||
) -> Result<()> {
|
||||
let mut prover_socket = prover_socket.compat();
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Set up Verifier.
|
||||
let config_validator = ProtocolConfigValidator::builder()
|
||||
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
|
||||
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
@@ -325,33 +317,25 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
.protocol_config_validator(config_validator)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let verifier = Verifier::new(verifier_config)
|
||||
.commit(&mut prover_socket)
|
||||
let mut verifier = Verifier::new(verifier_config)
|
||||
.setup(socket.compat())
|
||||
.await?
|
||||
.accept(&mut prover_socket)
|
||||
.await?
|
||||
.run(&mut prover_socket)
|
||||
.run()
|
||||
.await?;
|
||||
|
||||
let (
|
||||
VerifierOutput {
|
||||
transcript_commitments,
|
||||
encoder_secret,
|
||||
..
|
||||
},
|
||||
verifier,
|
||||
) = verifier
|
||||
.verify(&mut prover_socket)
|
||||
.await?
|
||||
.accept(&mut prover_socket)
|
||||
.await?;
|
||||
let VerifierOutput {
|
||||
transcript_commitments,
|
||||
encoder_secret,
|
||||
..
|
||||
} = verifier.verify(&VerifyConfig::default()).await?;
|
||||
|
||||
let tls_transcript = verifier.tls_transcript().clone();
|
||||
|
||||
verifier.close(&mut prover_socket).await?;
|
||||
verifier.close().await?;
|
||||
|
||||
let sent_len = tls_transcript
|
||||
.sent()
|
||||
@@ -413,7 +397,7 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
// Send attestation to prover.
|
||||
attestation_tx
|
||||
.send(attestation)
|
||||
.map_err(|_| anyhow!("prover is not receiving attestation"))?;
|
||||
.map_err(|_| "prover is not receiving attestation".to_string())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -12,8 +12,8 @@ use tlsn::{
|
||||
signing::VerifyingKey,
|
||||
CryptoProvider,
|
||||
},
|
||||
config::{CertificateDer, RootCertStore},
|
||||
verifier::ServerCertVerifier,
|
||||
webpki::{CertificateDer, RootCertStore},
|
||||
};
|
||||
use tlsn_examples::ExampleType;
|
||||
use tlsn_server_fixture_certs::CA_CERT_DER;
|
||||
|
||||
@@ -3,7 +3,6 @@ use std::{
|
||||
net::{IpAddr, SocketAddr},
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use http_body_util::Empty;
|
||||
use hyper::{body::Bytes, Request, StatusCode, Uri};
|
||||
use hyper_util::rt::TokioIo;
|
||||
@@ -12,18 +11,11 @@ use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
|
||||
use tracing::instrument;
|
||||
|
||||
use tlsn::{
|
||||
config::{
|
||||
prove::ProveConfig,
|
||||
prover::ProverConfig,
|
||||
tls::TlsClientConfig,
|
||||
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig, TlsCommitProtocolConfig},
|
||||
verifier::VerifierConfig,
|
||||
},
|
||||
config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore},
|
||||
connection::ServerName,
|
||||
prover::Prover,
|
||||
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
|
||||
transcript::PartialTranscript,
|
||||
verifier::{Verifier, VerifierOutput},
|
||||
webpki::{CertificateDer, RootCertStore},
|
||||
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
|
||||
};
|
||||
use tlsn_server_fixture::DEFAULT_FIXTURE_PORT;
|
||||
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
|
||||
@@ -54,7 +46,7 @@ async fn main() {
|
||||
let (prover_socket, verifier_socket) = tokio::io::duplex(1 << 23);
|
||||
let prover = prover(prover_socket, &server_addr, &uri);
|
||||
let verifier = verifier(verifier_socket);
|
||||
let (_, transcript) = tokio::try_join!(prover, verifier).unwrap();
|
||||
let (_, transcript) = tokio::join!(prover, verifier);
|
||||
|
||||
println!("Successfully verified {}", &uri);
|
||||
println!(
|
||||
@@ -72,57 +64,61 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
verifier_socket: T,
|
||||
server_addr: &SocketAddr,
|
||||
uri: &str,
|
||||
) -> Result<()> {
|
||||
let mut verifier_socket = verifier_socket.compat();
|
||||
|
||||
) {
|
||||
let uri = uri.parse::<Uri>().unwrap();
|
||||
assert_eq!(uri.scheme().unwrap().as_str(), "https");
|
||||
let server_domain = uri.authority().unwrap().host();
|
||||
|
||||
// Create a new prover and perform necessary setup.
|
||||
let prover = Prover::new(ProverConfig::builder().build()?)
|
||||
.commit(
|
||||
TlsCommitConfig::builder()
|
||||
// Select the TLS commitment protocol.
|
||||
.protocol(
|
||||
MpcTlsConfig::builder()
|
||||
// We must configure the amount of data we expect to exchange beforehand,
|
||||
// which will be preprocessed prior to the
|
||||
// connection. Reducing these limits will improve
|
||||
// performance.
|
||||
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
|
||||
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
|
||||
.build()?,
|
||||
)
|
||||
.build()?,
|
||||
&mut verifier_socket,
|
||||
)
|
||||
.await?;
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
// server-fixture.
|
||||
let mut tls_config_builder = TlsConfig::builder();
|
||||
tls_config_builder.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
});
|
||||
let tls_config = tls_config_builder.build().unwrap();
|
||||
|
||||
// Open a TCP connection to the server.
|
||||
let client_socket = tokio::net::TcpStream::connect(server_addr).await?.compat();
|
||||
// Set up protocol configuration for prover.
|
||||
let mut prover_config_builder = ProverConfig::builder();
|
||||
prover_config_builder
|
||||
.server_name(ServerName::Dns(server_domain.try_into().unwrap()))
|
||||
.tls_config(tls_config)
|
||||
.protocol_config(
|
||||
ProtocolConfig::builder()
|
||||
.max_sent_data(MAX_SENT_DATA)
|
||||
.max_recv_data(MAX_RECV_DATA)
|
||||
.build()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
// Bind the prover to the server connection.
|
||||
let (tls_connection, prover) = prover.setup(
|
||||
TlsClientConfig::builder()
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
// server-fixture.
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
.build()?,
|
||||
)?;
|
||||
let prover_config = prover_config_builder.build().unwrap();
|
||||
|
||||
let tls_connection = TokioIo::new(tls_connection.compat());
|
||||
// Create prover and connect to verifier.
|
||||
//
|
||||
// Perform the setup phase with the verifier.
|
||||
let prover = Prover::new(prover_config)
|
||||
.setup(verifier_socket.compat())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Connect to TLS Server.
|
||||
let tls_client_socket = tokio::net::TcpStream::connect(server_addr).await.unwrap();
|
||||
|
||||
// Pass server connection into the prover.
|
||||
let (mpc_tls_connection, prover_fut) =
|
||||
prover.connect(tls_client_socket.compat()).await.unwrap();
|
||||
|
||||
// Wrap the connection in a TokioIo compatibility layer to use it with hyper.
|
||||
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat());
|
||||
|
||||
// Spawn the Prover to run in the background.
|
||||
let prover_task = tokio::spawn(prover.run(client_socket, verifier_socket));
|
||||
let prover_task = tokio::spawn(prover_fut);
|
||||
|
||||
// MPC-TLS Handshake.
|
||||
let (mut request_sender, connection) =
|
||||
hyper::client::conn::http1::handshake(tls_connection).await?;
|
||||
hyper::client::conn::http1::handshake(mpc_tls_connection)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Spawn the connection to run in the background.
|
||||
tokio::spawn(connection);
|
||||
@@ -134,13 +130,14 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
.header("Connection", "close")
|
||||
.header("Secret", SECRET)
|
||||
.method("GET")
|
||||
.body(Empty::<Bytes>::new())?;
|
||||
let response = request_sender.send_request(request).await?;
|
||||
.body(Empty::<Bytes>::new())
|
||||
.unwrap();
|
||||
let response = request_sender.send_request(request).await.unwrap();
|
||||
|
||||
assert!(response.status() == StatusCode::OK);
|
||||
|
||||
// Create proof for the Verifier.
|
||||
let (mut prover, _, mut verifier_socket) = prover_task.await??;
|
||||
let mut prover = prover_task.await.unwrap().unwrap();
|
||||
|
||||
let mut builder = ProveConfig::builder(prover.transcript());
|
||||
|
||||
@@ -156,8 +153,10 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
.expect("the secret should be in the sent data");
|
||||
|
||||
// Reveal everything except for the secret.
|
||||
builder.reveal_sent(&(0..pos))?;
|
||||
builder.reveal_sent(&(pos + SECRET.len()..prover.transcript().sent().len()))?;
|
||||
builder.reveal_sent(&(0..pos)).unwrap();
|
||||
builder
|
||||
.reveal_sent(&(pos + SECRET.len()..prover.transcript().sent().len()))
|
||||
.unwrap();
|
||||
|
||||
// Find the substring "Dick".
|
||||
let pos = prover
|
||||
@@ -168,22 +167,27 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
.expect("the substring 'Dick' should be in the received data");
|
||||
|
||||
// Reveal everything except for the substring.
|
||||
builder.reveal_recv(&(0..pos))?;
|
||||
builder.reveal_recv(&(pos + 4..prover.transcript().received().len()))?;
|
||||
builder.reveal_recv(&(0..pos)).unwrap();
|
||||
builder
|
||||
.reveal_recv(&(pos + 4..prover.transcript().received().len()))
|
||||
.unwrap();
|
||||
|
||||
let config = builder.build()?;
|
||||
let config = builder.build().unwrap();
|
||||
|
||||
prover.prove(&config, &mut verifier_socket).await?;
|
||||
prover.close(&mut verifier_socket).await?;
|
||||
|
||||
Ok(())
|
||||
prover.prove(&config).await.unwrap();
|
||||
prover.close().await.unwrap();
|
||||
}
|
||||
|
||||
#[instrument(skip(socket))]
|
||||
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
socket: T,
|
||||
) -> Result<PartialTranscript> {
|
||||
let mut socket = socket.compat();
|
||||
) -> PartialTranscript {
|
||||
// Set up Verifier.
|
||||
let config_validator = ProtocolConfigValidator::builder()
|
||||
.max_sent_data(MAX_SENT_DATA)
|
||||
.max_recv_data(MAX_RECV_DATA)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
@@ -192,56 +196,20 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
.build()?;
|
||||
.protocol_config_validator(config_validator)
|
||||
.build()
|
||||
.unwrap();
|
||||
let verifier = Verifier::new(verifier_config);
|
||||
|
||||
// Validate the proposed configuration and then run the TLS commitment protocol.
|
||||
let verifier = verifier.commit(&mut socket).await?;
|
||||
|
||||
// This is the opportunity to ensure the prover does not attempt to overload the
|
||||
// verifier.
|
||||
let reject = if let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = verifier.request().protocol()
|
||||
{
|
||||
if mpc_tls_config.max_sent_data() > MAX_SENT_DATA {
|
||||
Some("max_sent_data is too large")
|
||||
} else if mpc_tls_config.max_recv_data() > MAX_RECV_DATA {
|
||||
Some("max_recv_data is too large")
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
Some("expecting to use MPC-TLS")
|
||||
};
|
||||
|
||||
if reject.is_some() {
|
||||
verifier.reject(&mut socket, reject).await?;
|
||||
return Err(anyhow::anyhow!("protocol configuration rejected"));
|
||||
}
|
||||
|
||||
// Runs the TLS commitment protocol to completion.
|
||||
let verifier = verifier.accept(&mut socket).await?.run(&mut socket).await?;
|
||||
|
||||
// Validate the proving request and then verify.
|
||||
let verifier = verifier.verify(&mut socket).await?;
|
||||
|
||||
if !verifier.request().server_identity() {
|
||||
let verifier = verifier
|
||||
.reject(&mut socket, Some("expecting to verify the server name"))
|
||||
.await?;
|
||||
verifier.close(&mut socket).await?;
|
||||
return Err(anyhow::anyhow!("prover did not reveal the server name"));
|
||||
}
|
||||
|
||||
let (
|
||||
VerifierOutput {
|
||||
server_name,
|
||||
transcript,
|
||||
..
|
||||
},
|
||||
verifier,
|
||||
) = verifier.accept(&mut socket).await?;
|
||||
|
||||
verifier.close(&mut socket).await?;
|
||||
// Receive authenticated data.
|
||||
let VerifierOutput {
|
||||
server_name,
|
||||
transcript,
|
||||
..
|
||||
} = verifier
|
||||
.verify(socket.compat(), &VerifyConfig::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = server_name.expect("prover should have revealed server name");
|
||||
let transcript = transcript.expect("prover should have revealed transcript data");
|
||||
@@ -264,7 +232,7 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
let ServerName::Dns(server_name) = server_name;
|
||||
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
|
||||
|
||||
Ok(transcript)
|
||||
transcript
|
||||
}
|
||||
|
||||
/// Render redacted bytes as `🙈`.
|
||||
|
||||
@@ -2,7 +2,6 @@ mod prover;
|
||||
mod types;
|
||||
mod verifier;
|
||||
|
||||
use anyhow::Result;
|
||||
use prover::prover;
|
||||
use std::{
|
||||
env,
|
||||
@@ -13,7 +12,7 @@ use tlsn_server_fixture_certs::SERVER_DOMAIN;
|
||||
use verifier::verifier;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into());
|
||||
@@ -26,15 +25,16 @@ async fn main() -> Result<()> {
|
||||
let uri = format!("https://{SERVER_DOMAIN}:{server_port}/elster");
|
||||
let server_ip: IpAddr = server_host
|
||||
.parse()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid IP address '{server_host}': {e}"))?;
|
||||
.map_err(|e| format!("Invalid IP address '{}': {}", server_host, e))?;
|
||||
let server_addr = SocketAddr::from((server_ip, server_port));
|
||||
|
||||
// Connect prover and verifier.
|
||||
let (prover_socket, verifier_socket) = tokio::io::duplex(1 << 23);
|
||||
let (prover_extra_socket, verifier_extra_socket) = tokio::io::duplex(1 << 23);
|
||||
|
||||
let (_, transcript) = tokio::try_join!(
|
||||
prover(prover_socket, &server_addr, &uri),
|
||||
verifier(verifier_socket)
|
||||
prover(prover_socket, prover_extra_socket, &server_addr, &uri),
|
||||
verifier(verifier_socket, verifier_extra_socket)
|
||||
)?;
|
||||
|
||||
println!("---");
|
||||
|
||||
@@ -4,7 +4,6 @@ use crate::types::received_commitments;
|
||||
|
||||
use super::types::ZKProofBundle;
|
||||
|
||||
use anyhow::Result;
|
||||
use chrono::{Datelike, Local, NaiveDate};
|
||||
use http_body_util::Empty;
|
||||
use hyper::{body::Bytes, header, Request, StatusCode, Uri};
|
||||
@@ -22,91 +21,87 @@ use spansy::{
|
||||
http::{BodyContent, Requests, Responses},
|
||||
Spanned,
|
||||
};
|
||||
use tls_server_fixture::{CA_CERT_DER, SERVER_DOMAIN};
|
||||
use tls_server_fixture::CA_CERT_DER;
|
||||
use tlsn::{
|
||||
config::{
|
||||
prove::{ProveConfig, ProveConfigBuilder},
|
||||
prover::ProverConfig,
|
||||
tls::TlsClientConfig,
|
||||
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig},
|
||||
},
|
||||
config::{CertificateDer, ProtocolConfig, RootCertStore},
|
||||
connection::ServerName,
|
||||
hash::HashAlgId,
|
||||
prover::Prover,
|
||||
prover::{ProveConfig, ProveConfigBuilder, Prover, ProverConfig, TlsConfig},
|
||||
transcript::{
|
||||
hash::{PlaintextHash, PlaintextHashSecret},
|
||||
Direction, TranscriptCommitConfig, TranscriptCommitConfigBuilder, TranscriptCommitmentKind,
|
||||
TranscriptSecret,
|
||||
},
|
||||
webpki::{CertificateDer, RootCertStore},
|
||||
};
|
||||
|
||||
use tlsn_examples::{MAX_RECV_DATA, MAX_SENT_DATA};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tlsn_examples::MAX_RECV_DATA;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
use tlsn_examples::MAX_SENT_DATA;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
|
||||
use tracing::instrument;
|
||||
|
||||
#[instrument(skip(verifier_socket))]
|
||||
#[instrument(skip(verifier_socket, verifier_extra_socket))]
|
||||
pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
verifier_socket: T,
|
||||
mut verifier_extra_socket: T,
|
||||
server_addr: &SocketAddr,
|
||||
uri: &str,
|
||||
) -> Result<()> {
|
||||
let mut verifier_socket = verifier_socket.compat();
|
||||
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let uri = uri.parse::<Uri>()?;
|
||||
|
||||
if uri.scheme().map(|s| s.as_str()) != Some("https") {
|
||||
return Err(anyhow::anyhow!("URI must use HTTPS scheme"));
|
||||
return Err("URI must use HTTPS scheme".into());
|
||||
}
|
||||
|
||||
let server_domain = uri
|
||||
.authority()
|
||||
.ok_or_else(|| anyhow::anyhow!("URI must have authority"))?
|
||||
.host();
|
||||
let server_domain = uri.authority().ok_or("URI must have authority")?.host();
|
||||
|
||||
// Create a new prover and perform necessary setup.
|
||||
let prover = Prover::new(ProverConfig::builder().build()?)
|
||||
.commit(
|
||||
TlsCommitConfig::builder()
|
||||
// Select the TLS commitment protocol.
|
||||
.protocol(
|
||||
MpcTlsConfig::builder()
|
||||
// We must configure the amount of data we expect to exchange beforehand,
|
||||
// which will be preprocessed prior to the
|
||||
// connection. Reducing these limits will improve
|
||||
// performance.
|
||||
.max_sent_data(MAX_SENT_DATA)
|
||||
.max_recv_data(MAX_RECV_DATA)
|
||||
.build()?,
|
||||
)
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
// server-fixture.
|
||||
let mut tls_config_builder = TlsConfig::builder();
|
||||
tls_config_builder.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
});
|
||||
let tls_config = tls_config_builder.build()?;
|
||||
|
||||
// Set up protocol configuration for prover.
|
||||
let mut prover_config_builder = ProverConfig::builder();
|
||||
prover_config_builder
|
||||
.server_name(ServerName::Dns(server_domain.try_into()?))
|
||||
.tls_config(tls_config)
|
||||
.protocol_config(
|
||||
ProtocolConfig::builder()
|
||||
.max_sent_data(MAX_SENT_DATA)
|
||||
.max_recv_data(MAX_RECV_DATA)
|
||||
.build()?,
|
||||
&mut verifier_socket,
|
||||
)
|
||||
);
|
||||
|
||||
let prover_config = prover_config_builder.build()?;
|
||||
|
||||
// Create prover and connect to verifier.
|
||||
//
|
||||
// Perform the setup phase with the verifier.
|
||||
let prover = Prover::new(prover_config)
|
||||
.setup(verifier_socket.compat())
|
||||
.await?;
|
||||
|
||||
// Open a TCP connection to the server.
|
||||
let client_socket = tokio::net::TcpStream::connect(server_addr).await?.compat();
|
||||
// Connect to TLS Server.
|
||||
let tls_client_socket = tokio::net::TcpStream::connect(server_addr).await?;
|
||||
|
||||
// Bind the prover to the server connection.
|
||||
let (tls_connection, prover) = prover.setup(
|
||||
TlsClientConfig::builder()
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
// server-fixture.
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
.build()?,
|
||||
)?;
|
||||
let tls_connection = TokioIo::new(tls_connection.compat());
|
||||
// Pass server connection into the prover.
|
||||
let (mpc_tls_connection, prover_fut) = prover.connect(tls_client_socket.compat()).await?;
|
||||
|
||||
// Wrap the connection in a TokioIo compatibility layer to use it with hyper.
|
||||
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat());
|
||||
|
||||
// Spawn the Prover to run in the background.
|
||||
let prover_task = tokio::spawn(prover.run(client_socket, verifier_socket));
|
||||
let prover_task = tokio::spawn(prover_fut);
|
||||
|
||||
// MPC-TLS Handshake.
|
||||
let (mut request_sender, connection) =
|
||||
hyper::client::conn::http1::handshake(tls_connection).await?;
|
||||
hyper::client::conn::http1::handshake(mpc_tls_connection).await?;
|
||||
|
||||
// Spawn the connection to run in the background.
|
||||
tokio::spawn(connection);
|
||||
@@ -123,14 +118,11 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
let response = request_sender.send_request(request).await?;
|
||||
|
||||
if response.status() != StatusCode::OK {
|
||||
return Err(anyhow::anyhow!(
|
||||
"MPC-TLS request failed with status {}",
|
||||
response.status()
|
||||
));
|
||||
return Err(format!("MPC-TLS request failed with status {}", response.status()).into());
|
||||
}
|
||||
|
||||
// Create proof for the Verifier.
|
||||
let (mut prover, _, mut verifier_socket) = prover_task.await??;
|
||||
let mut prover = prover_task.await??;
|
||||
|
||||
let transcript = prover.transcript().clone();
|
||||
let mut prove_config_builder = ProveConfig::builder(&transcript);
|
||||
@@ -164,56 +156,52 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
let prove_config = prove_config_builder.build()?;
|
||||
|
||||
// MPC-TLS prove
|
||||
let prover_output = prover.prove(&prove_config, &mut verifier_socket).await?;
|
||||
prover.close(&mut verifier_socket).await?;
|
||||
let prover_output = prover.prove(&prove_config).await?;
|
||||
prover.close().await?;
|
||||
|
||||
// Prove birthdate is more than 18 years ago.
|
||||
let received_commitments = received_commitments(&prover_output.transcript_commitments);
|
||||
let received_commitment = received_commitments
|
||||
.first()
|
||||
.ok_or_else(|| anyhow::anyhow!("No received commitments found"))?; // committed hash (of date of birth string)
|
||||
.ok_or("No received commitments found")?; // committed hash (of date of birth string)
|
||||
let received_secrets = received_secrets(&prover_output.transcript_secrets);
|
||||
let received_secret = received_secrets
|
||||
.first()
|
||||
.ok_or_else(|| anyhow::anyhow!("No received secrets found"))?; // hash blinder
|
||||
.ok_or("No received secrets found")?; // hash blinder
|
||||
let proof_input = prepare_zk_proof_input(received, received_commitment, received_secret)?;
|
||||
let proof_bundle = generate_zk_proof(&proof_input)?;
|
||||
|
||||
// Sent zk proof bundle to verifier
|
||||
let serialized_proof = bincode::serialize(&proof_bundle)?;
|
||||
|
||||
let mut verifier_socket = verifier_socket.into_inner();
|
||||
verifier_socket.write_all(&serialized_proof).await?;
|
||||
verifier_socket.shutdown().await?;
|
||||
verifier_extra_socket.write_all(&serialized_proof).await?;
|
||||
verifier_extra_socket.shutdown().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Reveal everything from the request, except for the authorization token.
|
||||
fn reveal_request(request: &[u8], builder: &mut ProveConfigBuilder<'_>) -> Result<()> {
|
||||
fn reveal_request(
|
||||
request: &[u8],
|
||||
builder: &mut ProveConfigBuilder<'_>,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let reqs = Requests::new_from_slice(request).collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let req = reqs
|
||||
.first()
|
||||
.ok_or_else(|| anyhow::anyhow!("No requests found"))?;
|
||||
let req = reqs.first().ok_or("No requests found")?;
|
||||
|
||||
if req.request.method.as_str() != "GET" {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Expected GET method, found {}",
|
||||
req.request.method.as_str()
|
||||
));
|
||||
return Err(format!("Expected GET method, found {}", req.request.method.as_str()).into());
|
||||
}
|
||||
|
||||
let authorization_header = req
|
||||
.headers_with_name(header::AUTHORIZATION.as_str())
|
||||
.next()
|
||||
.ok_or_else(|| anyhow::anyhow!("Authorization header not found"))?;
|
||||
.ok_or("Authorization header not found")?;
|
||||
|
||||
let start_pos = authorization_header
|
||||
.span()
|
||||
.indices()
|
||||
.min()
|
||||
.ok_or_else(|| anyhow::anyhow!("Could not find authorization header start position"))?
|
||||
.ok_or("Could not find authorization header start position")?
|
||||
+ header::AUTHORIZATION.as_str().len()
|
||||
+ 2;
|
||||
let end_pos =
|
||||
@@ -229,43 +217,38 @@ fn reveal_received(
|
||||
received: &[u8],
|
||||
builder: &mut ProveConfigBuilder<'_>,
|
||||
transcript_commitment_builder: &mut TranscriptCommitConfigBuilder,
|
||||
) -> Result<()> {
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let resp = Responses::new_from_slice(received).collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let response = resp
|
||||
.first()
|
||||
.ok_or_else(|| anyhow::anyhow!("No responses found"))?;
|
||||
let body = response
|
||||
.body
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Response body not found"))?;
|
||||
let response = resp.first().ok_or("No responses found")?;
|
||||
let body = response.body.as_ref().ok_or("Response body not found")?;
|
||||
|
||||
let BodyContent::Json(json) = &body.content else {
|
||||
return Err(anyhow::anyhow!("Expected JSON body content"));
|
||||
return Err("Expected JSON body content".into());
|
||||
};
|
||||
|
||||
// reveal tax year
|
||||
let tax_year = json
|
||||
.get("tax_year")
|
||||
.ok_or_else(|| anyhow::anyhow!("tax_year field not found in JSON"))?;
|
||||
.ok_or("tax_year field not found in JSON")?;
|
||||
let start_pos = tax_year
|
||||
.span()
|
||||
.indices()
|
||||
.min()
|
||||
.ok_or_else(|| anyhow::anyhow!("Could not find tax_year start position"))?
|
||||
.ok_or("Could not find tax_year start position")?
|
||||
- 11;
|
||||
let end_pos = tax_year
|
||||
.span()
|
||||
.indices()
|
||||
.max()
|
||||
.ok_or_else(|| anyhow::anyhow!("Could not find tax_year end position"))?
|
||||
.ok_or("Could not find tax_year end position")?
|
||||
+ 1;
|
||||
builder.reveal_recv(&(start_pos..end_pos))?;
|
||||
|
||||
// commit to hash of date of birth
|
||||
let dob = json
|
||||
.get("taxpayer.date_of_birth")
|
||||
.ok_or_else(|| anyhow::anyhow!("taxpayer.date_of_birth field not found in JSON"))?;
|
||||
.ok_or("taxpayer.date_of_birth field not found in JSON")?;
|
||||
|
||||
transcript_commitment_builder.commit_recv(dob.span())?;
|
||||
|
||||
@@ -296,7 +279,7 @@ fn prepare_zk_proof_input(
|
||||
received: &[u8],
|
||||
received_commitment: &PlaintextHash,
|
||||
received_secret: &PlaintextHashSecret,
|
||||
) -> Result<ZKProofInput> {
|
||||
) -> Result<ZKProofInput, Box<dyn std::error::Error>> {
|
||||
assert_eq!(received_commitment.direction, Direction::Received);
|
||||
assert_eq!(received_commitment.hash.alg, HashAlgId::SHA256);
|
||||
|
||||
@@ -305,11 +288,11 @@ fn prepare_zk_proof_input(
|
||||
let dob_start = received_commitment
|
||||
.idx
|
||||
.min()
|
||||
.ok_or_else(|| anyhow::anyhow!("No start index for DOB"))?;
|
||||
.ok_or("No start index for DOB")?;
|
||||
let dob_end = received_commitment
|
||||
.idx
|
||||
.end()
|
||||
.ok_or_else(|| anyhow::anyhow!("No end index for DOB"))?;
|
||||
.ok_or("No end index for DOB")?;
|
||||
let dob = received[dob_start..dob_end].to_vec();
|
||||
let blinder = received_secret.blinder.as_bytes().to_vec();
|
||||
let committed_hash = hash.value.as_bytes().to_vec();
|
||||
@@ -323,10 +306,8 @@ fn prepare_zk_proof_input(
|
||||
hasher.update(&blinder);
|
||||
let computed_hash = hasher.finalize();
|
||||
|
||||
if committed_hash != computed_hash.as_ref() as &[u8] {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Computed hash does not match committed hash"
|
||||
));
|
||||
if committed_hash != computed_hash.as_slice() {
|
||||
return Err("Computed hash does not match committed hash".into());
|
||||
}
|
||||
|
||||
Ok(ZKProofInput {
|
||||
@@ -337,7 +318,9 @@ fn prepare_zk_proof_input(
|
||||
})
|
||||
}
|
||||
|
||||
fn generate_zk_proof(proof_input: &ZKProofInput) -> Result<ZKProofBundle> {
|
||||
fn generate_zk_proof(
|
||||
proof_input: &ZKProofInput,
|
||||
) -> Result<ZKProofBundle, Box<dyn std::error::Error>> {
|
||||
tracing::info!("🔒 Generating ZK proof with Noir...");
|
||||
|
||||
const PROGRAM_JSON: &str = include_str!("./noir/target/noir.json");
|
||||
@@ -346,7 +329,7 @@ fn generate_zk_proof(proof_input: &ZKProofInput) -> Result<ZKProofBundle> {
|
||||
let json: Value = serde_json::from_str(PROGRAM_JSON)?;
|
||||
let bytecode = json["bytecode"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("bytecode field not found in program.json"))?;
|
||||
.ok_or("bytecode field not found in program.json")?;
|
||||
|
||||
let mut inputs: Vec<String> = vec![];
|
||||
inputs.push(proof_input.proof_date.day().to_string());
|
||||
@@ -371,17 +354,16 @@ fn generate_zk_proof(proof_input: &ZKProofInput) -> Result<ZKProofBundle> {
|
||||
tracing::debug!("Witness inputs {:?}", inputs);
|
||||
|
||||
let input_refs: Vec<&str> = inputs.iter().map(String::as_str).collect();
|
||||
let witness = from_vec_str_to_witness_map(input_refs).map_err(|e| anyhow::anyhow!(e))?;
|
||||
let witness = from_vec_str_to_witness_map(input_refs)?;
|
||||
|
||||
// Setup SRS
|
||||
setup_srs_from_bytecode(bytecode, None, false).map_err(|e| anyhow::anyhow!(e))?;
|
||||
setup_srs_from_bytecode(bytecode, None, false)?;
|
||||
|
||||
// Verification key
|
||||
let vk = get_ultra_honk_verification_key(bytecode, false).map_err(|e| anyhow::anyhow!(e))?;
|
||||
let vk = get_ultra_honk_verification_key(bytecode, false)?;
|
||||
|
||||
// Generate proof
|
||||
let proof = prove_ultra_honk(bytecode, witness.clone(), vk.clone(), false)
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
let proof = prove_ultra_honk(bytecode, witness.clone(), vk.clone(), false)?;
|
||||
tracing::info!("✅ Proof generated ({} bytes)", proof.len());
|
||||
|
||||
let proof_bundle = ZKProofBundle { vk, proof };
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
use crate::types::received_commitments;
|
||||
|
||||
use super::types::ZKProofBundle;
|
||||
use anyhow::Result;
|
||||
use chrono::{Local, NaiveDate};
|
||||
use noir::barretenberg::verify::{get_ultra_honk_verification_key, verify_ultra_honk};
|
||||
use serde_json::Value;
|
||||
use tls_server_fixture::CA_CERT_DER;
|
||||
use tlsn::{
|
||||
config::{tls_commit::TlsCommitProtocolConfig, verifier::VerifierConfig},
|
||||
config::{CertificateDer, ProtocolConfigValidator, RootCertStore},
|
||||
connection::ServerName,
|
||||
hash::HashAlgId,
|
||||
transcript::{Direction, PartialTranscript},
|
||||
verifier::{Verifier, VerifierOutput},
|
||||
webpki::{CertificateDer, RootCertStore},
|
||||
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
|
||||
};
|
||||
use tlsn_examples::{MAX_RECV_DATA, MAX_SENT_DATA};
|
||||
use tlsn_server_fixture_certs::SERVER_DOMAIN;
|
||||
@@ -20,100 +18,60 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||
use tokio_util::compat::TokioAsyncReadCompatExt;
|
||||
use tracing::instrument;
|
||||
|
||||
#[instrument(skip(prover_socket))]
|
||||
#[instrument(skip(socket, extra_socket))]
|
||||
pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
prover_socket: T,
|
||||
) -> Result<PartialTranscript> {
|
||||
let mut prover_socket = prover_socket.compat();
|
||||
socket: T,
|
||||
mut extra_socket: T,
|
||||
) -> Result<PartialTranscript, Box<dyn std::error::Error>> {
|
||||
// Set up Verifier.
|
||||
let config_validator = ProtocolConfigValidator::builder()
|
||||
.max_sent_data(MAX_SENT_DATA)
|
||||
.max_recv_data(MAX_RECV_DATA)
|
||||
.build()?;
|
||||
|
||||
let verifier = Verifier::new(
|
||||
VerifierConfig::builder()
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
// server-fixture.
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
.build()?,
|
||||
);
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
// server-fixture.
|
||||
let verifier_config = VerifierConfig::builder()
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
.protocol_config_validator(config_validator)
|
||||
.build()?;
|
||||
|
||||
// Validate the proposed configuration and then run the TLS commitment protocol.
|
||||
let verifier = verifier.commit(&mut prover_socket).await?;
|
||||
let verifier = Verifier::new(verifier_config);
|
||||
|
||||
// This is the opportunity to ensure the prover does not attempt to overload the
|
||||
// verifier.
|
||||
let reject = if let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = verifier.request().protocol()
|
||||
{
|
||||
if mpc_tls_config.max_sent_data() > MAX_SENT_DATA {
|
||||
Some("max_sent_data is too large")
|
||||
} else if mpc_tls_config.max_recv_data() > MAX_RECV_DATA {
|
||||
Some("max_recv_data is too large")
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
Some("expecting to use MPC-TLS")
|
||||
};
|
||||
|
||||
if reject.is_some() {
|
||||
verifier.reject(&mut prover_socket, reject).await?;
|
||||
return Err(anyhow::anyhow!("protocol configuration rejected"));
|
||||
}
|
||||
|
||||
// Runs the TLS commitment protocol to completion.
|
||||
let verifier = verifier
|
||||
.accept(&mut prover_socket)
|
||||
.await?
|
||||
.run(&mut prover_socket)
|
||||
// Receive authenticated data.
|
||||
let VerifierOutput {
|
||||
server_name,
|
||||
transcript,
|
||||
transcript_commitments,
|
||||
..
|
||||
} = verifier
|
||||
.verify(socket.compat(), &VerifyConfig::default())
|
||||
.await?;
|
||||
|
||||
// Validate the proving request and then verify.
|
||||
let verifier = verifier.verify(&mut prover_socket).await?;
|
||||
let request = verifier.request();
|
||||
|
||||
if !request.server_identity() || request.reveal().is_none() {
|
||||
let verifier = verifier
|
||||
.reject(
|
||||
&mut prover_socket,
|
||||
Some("expecting to verify the server name and transcript data"),
|
||||
)
|
||||
.await?;
|
||||
verifier.close(&mut prover_socket).await?;
|
||||
return Err(anyhow::anyhow!(
|
||||
"prover did not reveal the server name and transcript data"
|
||||
));
|
||||
}
|
||||
|
||||
let (
|
||||
VerifierOutput {
|
||||
server_name,
|
||||
transcript,
|
||||
transcript_commitments,
|
||||
..
|
||||
},
|
||||
verifier,
|
||||
) = verifier.accept(&mut prover_socket).await?;
|
||||
|
||||
verifier.close(&mut prover_socket).await?;
|
||||
let server_name = server_name.expect("server name should be present");
|
||||
let transcript = transcript.expect("transcript should be present");
|
||||
let server_name = server_name.ok_or("Prover should have revealed server name")?;
|
||||
let transcript = transcript.ok_or("Prover should have revealed transcript data")?;
|
||||
|
||||
// Create hash commitment for the date of birth field from the response
|
||||
let sent = transcript.sent_unsafe().to_vec();
|
||||
let sent_data = String::from_utf8(sent.clone())
|
||||
.map_err(|e| anyhow::anyhow!("Verifier expected valid UTF-8 sent data: {e}"))?;
|
||||
.map_err(|e| format!("Verifier expected valid UTF-8 sent data: {}", e))?;
|
||||
|
||||
if !sent_data.contains(SERVER_DOMAIN) {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Verification failed: Expected host {SERVER_DOMAIN} not found in sent data"
|
||||
));
|
||||
return Err(format!(
|
||||
"Verification failed: Expected host {} not found in sent data",
|
||||
SERVER_DOMAIN
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
// Check received data.
|
||||
let received_commitments = received_commitments(&transcript_commitments);
|
||||
let received_commitment = received_commitments
|
||||
.first()
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing hash commitment"))?;
|
||||
.ok_or("Missing received hash commitment")?;
|
||||
|
||||
assert!(received_commitment.direction == Direction::Received);
|
||||
assert!(received_commitment.hash.alg == HashAlgId::SHA256);
|
||||
@@ -123,41 +81,39 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
|
||||
// Check Session info: server name.
|
||||
let ServerName::Dns(server_name) = server_name;
|
||||
if server_name.as_str() != SERVER_DOMAIN {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Server name mismatch: expected {SERVER_DOMAIN}, got {}",
|
||||
return Err(format!(
|
||||
"Server name mismatch: expected {}, got {}",
|
||||
SERVER_DOMAIN,
|
||||
server_name.as_str()
|
||||
));
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
// Receive ZKProof information from prover
|
||||
let mut buf = Vec::new();
|
||||
|
||||
let mut prover_socket = prover_socket.into_inner();
|
||||
prover_socket.read_to_end(&mut buf).await?;
|
||||
extra_socket.read_to_end(&mut buf).await?;
|
||||
|
||||
if buf.is_empty() {
|
||||
return Err(anyhow::anyhow!("No ZK proof data received from prover"));
|
||||
return Err("No ZK proof data received from prover".into());
|
||||
}
|
||||
|
||||
let msg: ZKProofBundle = bincode::deserialize(&buf)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to deserialize ZK proof bundle: {e}"))?;
|
||||
.map_err(|e| format!("Failed to deserialize ZK proof bundle: {}", e))?;
|
||||
|
||||
// Verify zk proof
|
||||
const PROGRAM_JSON: &str = include_str!("./noir/target/noir.json");
|
||||
let json: Value = serde_json::from_str(PROGRAM_JSON)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse Noir circuit: {e}"))?;
|
||||
.map_err(|e| format!("Failed to parse Noir circuit: {}", e))?;
|
||||
|
||||
let bytecode = json["bytecode"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("Bytecode field missing in noir.json"))?;
|
||||
.ok_or("Bytecode field missing in noir.json")?;
|
||||
|
||||
let vk = get_ultra_honk_verification_key(bytecode, false)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to get verification key: {e}"))?;
|
||||
.map_err(|e| format!("Failed to get verification key: {}", e))?;
|
||||
|
||||
if vk != msg.vk {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Verification key mismatch between computed and provided by prover"
|
||||
));
|
||||
return Err("Verification key mismatch between computed and provided by prover".into());
|
||||
}
|
||||
|
||||
let proof = msg.proof.clone();
|
||||
@@ -169,10 +125,12 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
|
||||
// * and 32*32 bytes for the hash
|
||||
let min_bytes = (32 + 3) * 32;
|
||||
if proof.len() < min_bytes {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Proof too short: expected at least {min_bytes} bytes, got {}",
|
||||
return Err(format!(
|
||||
"Proof too short: expected at least {} bytes, got {}",
|
||||
min_bytes,
|
||||
proof.len()
|
||||
));
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
// Check that the proof date is correctly included in the proof
|
||||
@@ -181,12 +139,14 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
|
||||
let proof_date_year: i32 = i32::from_be_bytes(proof[92..96].try_into()?);
|
||||
let proof_date_from_proof =
|
||||
NaiveDate::from_ymd_opt(proof_date_year, proof_date_month, proof_date_day)
|
||||
.ok_or_else(|| anyhow::anyhow!("Invalid proof date in proof"))?;
|
||||
.ok_or("Invalid proof date in proof")?;
|
||||
let today = Local::now().date_naive();
|
||||
if (today - proof_date_from_proof).num_days() < 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"The proof date can only be today or in the past: provided {proof_date_from_proof}, today {today}"
|
||||
));
|
||||
return Err(format!(
|
||||
"The proof date can only be today or in the past: provided {}, today {}",
|
||||
proof_date_from_proof, today
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
// Check that the committed hash in the proof matches the hash from the
|
||||
@@ -204,9 +164,7 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
|
||||
hex::encode(&committed_hash_in_proof),
|
||||
hex::encode(&expected_hash)
|
||||
);
|
||||
return Err(anyhow::anyhow!(
|
||||
"Hash in proof does not match committed hash in MPC-TLS"
|
||||
));
|
||||
return Err("Hash in proof does not match committed hash in MPC-TLS".into());
|
||||
}
|
||||
tracing::info!(
|
||||
"✅ The hash in the proof matches the committed hash in MPC-TLS ({})",
|
||||
@@ -215,10 +173,10 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
|
||||
|
||||
// Finally verify the proof
|
||||
let is_valid = verify_ultra_honk(msg.proof, msg.vk)
|
||||
.map_err(|e| anyhow::anyhow!("ZKProof Verification failed: {e}"))?;
|
||||
.map_err(|e| format!("ZKProof Verification failed: {}", e))?;
|
||||
if !is_valid {
|
||||
tracing::error!("❌ Age verification ZKProof failed to verify");
|
||||
return Err(anyhow::anyhow!("Age verification ZKProof failed to verify"));
|
||||
return Err("Age verification ZKProof failed to verify".into());
|
||||
}
|
||||
tracing::info!("✅ Age verification ZKProof successfully verified");
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tlsn-formats"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
version = "0.1.0-alpha.13"
|
||||
edition = "2021"
|
||||
|
||||
[lints]
|
||||
|
||||
@@ -1,59 +1,51 @@
|
||||
#### Default Representative Benchmarks ####
|
||||
#
|
||||
# This benchmark measures TLSNotary performance on three representative network scenarios.
|
||||
# Each scenario is run multiple times to produce statistical metrics (median, std dev, etc.)
|
||||
# rather than plots. Use this for quick performance checks and CI regression testing.
|
||||
#
|
||||
# Payload sizes:
|
||||
# - upload-size: 1KB (typical HTTP request)
|
||||
# - download-size: 2KB (typical HTTP response/API data)
|
||||
#
|
||||
# Network scenarios are chosen to represent real-world user conditions where
|
||||
# TLSNotary is primarily bottlenecked by upload bandwidth.
|
||||
|
||||
#### Cable/DSL Home Internet ####
|
||||
# Most common residential internet connection
|
||||
# - Asymmetric: high download, limited upload (typical bottleneck)
|
||||
# - Upload bandwidth: 20 Mbps (realistic cable/DSL upload speed)
|
||||
# - Latency: 20ms (typical ISP latency)
|
||||
#### Latency ####
|
||||
|
||||
[[group]]
|
||||
name = "cable"
|
||||
bandwidth = 20
|
||||
protocol_latency = 20
|
||||
upload-size = 1024
|
||||
download-size = 2048
|
||||
name = "latency"
|
||||
bandwidth = 1000
|
||||
|
||||
[[bench]]
|
||||
group = "cable"
|
||||
|
||||
#### Mobile 5G ####
|
||||
# Modern mobile connection with good coverage
|
||||
# - Upload bandwidth: 30 Mbps (typical 5G upload in good conditions)
|
||||
# - Latency: 30ms (higher than wired due to mobile tower hops)
|
||||
|
||||
[[group]]
|
||||
name = "mobile_5g"
|
||||
bandwidth = 30
|
||||
protocol_latency = 30
|
||||
upload-size = 1024
|
||||
download-size = 2048
|
||||
group = "latency"
|
||||
protocol_latency = 10
|
||||
|
||||
[[bench]]
|
||||
group = "mobile_5g"
|
||||
group = "latency"
|
||||
protocol_latency = 25
|
||||
|
||||
#### Fiber Home Internet ####
|
||||
# High-end residential connection (best case scenario)
|
||||
# - Symmetric: equal upload/download bandwidth
|
||||
# - Upload bandwidth: 100 Mbps (typical fiber upload)
|
||||
# - Latency: 15ms (lower latency than cable)
|
||||
[[bench]]
|
||||
group = "latency"
|
||||
protocol_latency = 50
|
||||
|
||||
[[bench]]
|
||||
group = "latency"
|
||||
protocol_latency = 100
|
||||
|
||||
[[bench]]
|
||||
group = "latency"
|
||||
protocol_latency = 200
|
||||
|
||||
#### Bandwidth ####
|
||||
|
||||
[[group]]
|
||||
name = "fiber"
|
||||
name = "bandwidth"
|
||||
protocol_latency = 25
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth"
|
||||
bandwidth = 10
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth"
|
||||
bandwidth = 50
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth"
|
||||
bandwidth = 100
|
||||
protocol_latency = 15
|
||||
upload-size = 1024
|
||||
download-size = 2048
|
||||
|
||||
[[bench]]
|
||||
group = "fiber"
|
||||
group = "bandwidth"
|
||||
bandwidth = 250
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth"
|
||||
bandwidth = 1000
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
#### Bandwidth Sweep Benchmark ####
|
||||
#
|
||||
# Measures how network bandwidth affects TLSNotary runtime.
|
||||
# Keeps latency and payload sizes fixed while varying upload bandwidth.
|
||||
#
|
||||
# Fixed parameters:
|
||||
# - Latency: 25ms (typical internet latency)
|
||||
# - Upload: 1KB (typical request)
|
||||
# - Download: 2KB (typical response)
|
||||
#
|
||||
# Variable: Bandwidth from 5 Mbps to 1000 Mbps
|
||||
#
|
||||
# Use this to plot "Bandwidth vs Runtime" and understand bandwidth sensitivity.
|
||||
# Focus on upload bandwidth as TLSNotary is primarily upload-bottlenecked
|
||||
|
||||
[[group]]
|
||||
name = "bandwidth_sweep"
|
||||
protocol_latency = 25
|
||||
upload-size = 1024
|
||||
download-size = 2048
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth_sweep"
|
||||
bandwidth = 5
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth_sweep"
|
||||
bandwidth = 10
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth_sweep"
|
||||
bandwidth = 20
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth_sweep"
|
||||
bandwidth = 50
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth_sweep"
|
||||
bandwidth = 100
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth_sweep"
|
||||
bandwidth = 250
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth_sweep"
|
||||
bandwidth = 500
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth_sweep"
|
||||
bandwidth = 1000
|
||||
@@ -1,53 +0,0 @@
|
||||
#### Download Size Sweep Benchmark ####
|
||||
#
|
||||
# Measures how download payload size affects TLSNotary runtime.
|
||||
# Keeps network conditions fixed while varying the response size.
|
||||
#
|
||||
# Fixed parameters:
|
||||
# - Bandwidth: 100 Mbps (typical good connection)
|
||||
# - Latency: 25ms (typical internet latency)
|
||||
# - Upload: 1KB (typical request size)
|
||||
#
|
||||
# Variable: Download size from 1KB to 100KB
|
||||
#
|
||||
# Use this to plot "Download Size vs Runtime" and understand how much data
|
||||
# TLSNotary can efficiently notarize. Useful for determining optimal
|
||||
# chunking strategies for large responses.
|
||||
|
||||
[[group]]
|
||||
name = "download_sweep"
|
||||
bandwidth = 100
|
||||
protocol_latency = 25
|
||||
upload-size = 1024
|
||||
|
||||
[[bench]]
|
||||
group = "download_sweep"
|
||||
download-size = 1024
|
||||
|
||||
[[bench]]
|
||||
group = "download_sweep"
|
||||
download-size = 2048
|
||||
|
||||
[[bench]]
|
||||
group = "download_sweep"
|
||||
download-size = 5120
|
||||
|
||||
[[bench]]
|
||||
group = "download_sweep"
|
||||
download-size = 10240
|
||||
|
||||
[[bench]]
|
||||
group = "download_sweep"
|
||||
download-size = 20480
|
||||
|
||||
[[bench]]
|
||||
group = "download_sweep"
|
||||
download-size = 30720
|
||||
|
||||
[[bench]]
|
||||
group = "download_sweep"
|
||||
download-size = 40960
|
||||
|
||||
[[bench]]
|
||||
group = "download_sweep"
|
||||
download-size = 51200
|
||||
@@ -1,47 +0,0 @@
|
||||
#### Latency Sweep Benchmark ####
|
||||
#
|
||||
# Measures how network latency affects TLSNotary runtime.
|
||||
# Keeps bandwidth and payload sizes fixed while varying protocol latency.
|
||||
#
|
||||
# Fixed parameters:
|
||||
# - Bandwidth: 100 Mbps (typical good connection)
|
||||
# - Upload: 1KB (typical request)
|
||||
# - Download: 2KB (typical response)
|
||||
#
|
||||
# Variable: Protocol latency from 10ms to 200ms
|
||||
#
|
||||
# Use this to plot "Latency vs Runtime" and understand latency sensitivity.
|
||||
|
||||
[[group]]
|
||||
name = "latency_sweep"
|
||||
bandwidth = 100
|
||||
upload-size = 1024
|
||||
download-size = 2048
|
||||
|
||||
[[bench]]
|
||||
group = "latency_sweep"
|
||||
protocol_latency = 10
|
||||
|
||||
[[bench]]
|
||||
group = "latency_sweep"
|
||||
protocol_latency = 25
|
||||
|
||||
[[bench]]
|
||||
group = "latency_sweep"
|
||||
protocol_latency = 50
|
||||
|
||||
[[bench]]
|
||||
group = "latency_sweep"
|
||||
protocol_latency = 75
|
||||
|
||||
[[bench]]
|
||||
group = "latency_sweep"
|
||||
protocol_latency = 100
|
||||
|
||||
[[bench]]
|
||||
group = "latency_sweep"
|
||||
protocol_latency = 150
|
||||
|
||||
[[bench]]
|
||||
group = "latency_sweep"
|
||||
protocol_latency = 200
|
||||
@@ -9,7 +9,6 @@ pub const DEFAULT_UPLOAD_SIZE: usize = 1024;
|
||||
pub const DEFAULT_DOWNLOAD_SIZE: usize = 4096;
|
||||
pub const DEFAULT_DEFER_DECRYPTION: bool = true;
|
||||
pub const DEFAULT_MEMORY_PROFILE: bool = false;
|
||||
pub const DEFAULT_REVEAL_ALL: bool = false;
|
||||
|
||||
pub const WARM_UP_BENCH: Bench = Bench {
|
||||
group: None,
|
||||
@@ -21,7 +20,6 @@ pub const WARM_UP_BENCH: Bench = Bench {
|
||||
download_size: 4096,
|
||||
defer_decryption: true,
|
||||
memory_profile: false,
|
||||
reveal_all: true,
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -81,8 +79,6 @@ pub struct BenchGroupItem {
|
||||
pub defer_decryption: Option<bool>,
|
||||
#[serde(rename = "memory-profile")]
|
||||
pub memory_profile: Option<bool>,
|
||||
#[serde(rename = "reveal-all")]
|
||||
pub reveal_all: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -101,8 +97,6 @@ pub struct BenchItem {
|
||||
pub defer_decryption: Option<bool>,
|
||||
#[serde(rename = "memory-profile")]
|
||||
pub memory_profile: Option<bool>,
|
||||
#[serde(rename = "reveal-all")]
|
||||
pub reveal_all: Option<bool>,
|
||||
}
|
||||
|
||||
impl BenchItem {
|
||||
@@ -138,10 +132,6 @@ impl BenchItem {
|
||||
if self.memory_profile.is_none() {
|
||||
self.memory_profile = group.memory_profile;
|
||||
}
|
||||
|
||||
if self.reveal_all.is_none() {
|
||||
self.reveal_all = group.reveal_all;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_bench(&self) -> Bench {
|
||||
@@ -155,7 +145,6 @@ impl BenchItem {
|
||||
download_size: self.download_size.unwrap_or(DEFAULT_DOWNLOAD_SIZE),
|
||||
defer_decryption: self.defer_decryption.unwrap_or(DEFAULT_DEFER_DECRYPTION),
|
||||
memory_profile: self.memory_profile.unwrap_or(DEFAULT_MEMORY_PROFILE),
|
||||
reveal_all: self.reveal_all.unwrap_or(DEFAULT_REVEAL_ALL),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -175,8 +164,6 @@ pub struct Bench {
|
||||
pub defer_decryption: bool,
|
||||
#[serde(rename = "memory-profile")]
|
||||
pub memory_profile: bool,
|
||||
#[serde(rename = "reveal-all")]
|
||||
pub reveal_all: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
||||
@@ -22,10 +22,7 @@ pub enum CmdOutput {
|
||||
GetTests(Vec<String>),
|
||||
Test(TestOutput),
|
||||
Bench(BenchOutput),
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
Fail {
|
||||
reason: Option<String>,
|
||||
},
|
||||
Fail { reason: Option<String> },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
||||
@@ -5,15 +5,9 @@ use futures::{AsyncReadExt, AsyncWriteExt, TryFutureExt};
|
||||
|
||||
use harness_core::bench::{Bench, ProverMetrics};
|
||||
use tlsn::{
|
||||
config::{
|
||||
prove::ProveConfig,
|
||||
prover::ProverConfig,
|
||||
tls::TlsClientConfig,
|
||||
tls_commit::{TlsCommitConfig, mpc::MpcTlsConfig},
|
||||
},
|
||||
config::{CertificateDer, ProtocolConfig, RootCertStore},
|
||||
connection::ServerName,
|
||||
prover::Prover,
|
||||
webpki::{CertificateDer, RootCertStore},
|
||||
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
|
||||
};
|
||||
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
|
||||
|
||||
@@ -23,54 +17,48 @@ use crate::{
|
||||
};
|
||||
|
||||
pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<ProverMetrics> {
|
||||
let mut verifier_io = Meter::new(provider.provide_proto_io().await?);
|
||||
let mut server_io = provider.provide_server_io().await?;
|
||||
let verifier_io = Meter::new(provider.provide_proto_io().await?);
|
||||
|
||||
let sent = verifier_io.sent();
|
||||
let recv = verifier_io.recv();
|
||||
|
||||
let prover = Prover::new(ProverConfig::builder().build()?);
|
||||
let mut builder = ProtocolConfig::builder();
|
||||
builder.max_sent_data(config.upload_size);
|
||||
|
||||
builder.defer_decryption_from_start(config.defer_decryption);
|
||||
if !config.defer_decryption {
|
||||
builder.max_recv_data_online(config.download_size + RECV_PADDING);
|
||||
}
|
||||
builder.max_recv_data(config.download_size + RECV_PADDING);
|
||||
|
||||
let protocol_config = builder.build()?;
|
||||
|
||||
let mut tls_config_builder = TlsConfig::builder();
|
||||
tls_config_builder.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
});
|
||||
let tls_config = tls_config_builder.build()?;
|
||||
|
||||
let prover = Prover::new(
|
||||
ProverConfig::builder()
|
||||
.tls_config(tls_config)
|
||||
.protocol_config(protocol_config)
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
|
||||
.build()?,
|
||||
);
|
||||
|
||||
let time_start = web_time::Instant::now();
|
||||
|
||||
let prover = prover
|
||||
.commit(
|
||||
TlsCommitConfig::builder()
|
||||
.protocol({
|
||||
let mut builder = MpcTlsConfig::builder()
|
||||
.max_sent_data(config.upload_size)
|
||||
.defer_decryption_from_start(config.defer_decryption);
|
||||
|
||||
if !config.defer_decryption {
|
||||
builder = builder.max_recv_data_online(config.download_size + RECV_PADDING);
|
||||
}
|
||||
|
||||
builder
|
||||
.max_recv_data(config.download_size + RECV_PADDING)
|
||||
.build()
|
||||
}?)
|
||||
.build()?,
|
||||
&mut verifier_io,
|
||||
)
|
||||
.await?;
|
||||
let prover = prover.setup(verifier_io).await?;
|
||||
|
||||
let time_preprocess = time_start.elapsed().as_millis();
|
||||
let time_start_online = web_time::Instant::now();
|
||||
let uploaded_preprocess = sent.load(Ordering::Relaxed);
|
||||
let downloaded_preprocess = recv.load(Ordering::Relaxed);
|
||||
|
||||
let (mut conn, prover) = prover.setup(
|
||||
TlsClientConfig::builder()
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
.build()?,
|
||||
)?;
|
||||
let (mut conn, prover_fut) = prover.connect(provider.provide_server_io().await?).await?;
|
||||
|
||||
let mut prover = prover.connect(&mut server_io, &mut verifier_io);
|
||||
|
||||
futures::try_join!(
|
||||
let (_, mut prover) = futures::try_join!(
|
||||
async {
|
||||
let request = format!(
|
||||
"GET /bytes?size={} HTTP/1.1\r\nConnection: close\r\nData: {}\r\n\r\n",
|
||||
@@ -87,9 +75,8 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
|
||||
|
||||
Ok(())
|
||||
},
|
||||
(&mut prover).map_err(anyhow::Error::from)
|
||||
prover_fut.map_err(anyhow::Error::from)
|
||||
)?;
|
||||
let mut prover = prover.finish()?;
|
||||
|
||||
let time_online = time_start_online.elapsed().as_millis();
|
||||
let uploaded_online = sent.load(Ordering::Relaxed) - uploaded_preprocess;
|
||||
@@ -99,28 +86,15 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
|
||||
|
||||
let mut builder = ProveConfig::builder(prover.transcript());
|
||||
|
||||
// When reveal_all is false (the default), we exclude 1 byte to avoid the
|
||||
// reveal-all optimization and benchmark the realistic ZK authentication path.
|
||||
let reveal_sent_range = if config.reveal_all {
|
||||
0..sent_len
|
||||
} else {
|
||||
0..sent_len.saturating_sub(1)
|
||||
};
|
||||
let reveal_recv_range = if config.reveal_all {
|
||||
0..recv_len
|
||||
} else {
|
||||
0..recv_len.saturating_sub(1)
|
||||
};
|
||||
|
||||
builder
|
||||
.server_identity()
|
||||
.reveal_sent(&reveal_sent_range)?
|
||||
.reveal_recv(&reveal_recv_range)?;
|
||||
.reveal_sent(&(0..sent_len))?
|
||||
.reveal_recv(&(0..recv_len))?;
|
||||
|
||||
let prove_config = builder.build()?;
|
||||
let config = builder.build()?;
|
||||
|
||||
prover.prove(&prove_config, &mut verifier_io).await?;
|
||||
prover.close(&mut verifier_io).await?;
|
||||
prover.prove(&config).await?;
|
||||
prover.close().await?;
|
||||
|
||||
let time_total = time_start.elapsed().as_millis();
|
||||
|
||||
|
||||
@@ -2,38 +2,34 @@ use anyhow::Result;
|
||||
|
||||
use harness_core::bench::Bench;
|
||||
use tlsn::{
|
||||
config::verifier::VerifierConfig,
|
||||
verifier::Verifier,
|
||||
webpki::{CertificateDer, RootCertStore},
|
||||
config::{CertificateDer, ProtocolConfigValidator, RootCertStore},
|
||||
verifier::{Verifier, VerifierConfig, VerifyConfig},
|
||||
};
|
||||
use tlsn_server_fixture_certs::CA_CERT_DER;
|
||||
|
||||
use crate::IoProvider;
|
||||
use crate::{IoProvider, bench::RECV_PADDING};
|
||||
|
||||
pub async fn bench_verifier(provider: &IoProvider, _config: &Bench) -> Result<()> {
|
||||
let mut prover_io = provider.provide_proto_io().await?;
|
||||
pub async fn bench_verifier(provider: &IoProvider, config: &Bench) -> Result<()> {
|
||||
let mut builder = ProtocolConfigValidator::builder();
|
||||
builder
|
||||
.max_sent_data(config.upload_size)
|
||||
.max_recv_data(config.download_size + RECV_PADDING);
|
||||
|
||||
let protocol_config = builder.build()?;
|
||||
|
||||
let verifier = Verifier::new(
|
||||
VerifierConfig::builder()
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
.protocol_config_validator(protocol_config)
|
||||
.build()?,
|
||||
);
|
||||
|
||||
let verifier = verifier
|
||||
.commit(provider.provide_proto_io().await?)
|
||||
.await?
|
||||
.accept(&mut prover_io)
|
||||
.await?
|
||||
.run(&mut prover_io)
|
||||
.await?;
|
||||
let (_, verifier) = verifier
|
||||
.verify(&mut prover_io)
|
||||
.await?
|
||||
.accept(&mut prover_io)
|
||||
.await?;
|
||||
verifier.close(&mut prover_io).await?;
|
||||
let verifier = verifier.setup(provider.provide_proto_io().await?).await?;
|
||||
let mut verifier = verifier.run().await?;
|
||||
verifier.verify(&VerifyConfig::default()).await?;
|
||||
verifier.close().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,17 +1,10 @@
|
||||
use tlsn::{
|
||||
config::{
|
||||
prove::ProveConfig,
|
||||
prover::ProverConfig,
|
||||
tls::TlsClientConfig,
|
||||
tls_commit::{TlsCommitConfig, mpc::MpcTlsConfig},
|
||||
verifier::VerifierConfig,
|
||||
},
|
||||
config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore},
|
||||
connection::ServerName,
|
||||
hash::HashAlgId,
|
||||
prover::Prover,
|
||||
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
|
||||
transcript::{TranscriptCommitConfig, TranscriptCommitment, TranscriptCommitmentKind},
|
||||
verifier::{Verifier, VerifierOutput},
|
||||
webpki::{CertificateDer, RootCertStore},
|
||||
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
|
||||
};
|
||||
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
|
||||
|
||||
@@ -28,41 +21,39 @@ const MAX_RECV_DATA: usize = 1 << 11;
|
||||
crate::test!("basic", prover, verifier);
|
||||
|
||||
async fn prover(provider: &IoProvider) {
|
||||
let mut verifier_io = provider.provide_proto_io().await.unwrap();
|
||||
let mut tls_config_builder = TlsConfig::builder();
|
||||
tls_config_builder.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
});
|
||||
|
||||
let prover = Prover::new(ProverConfig::builder().build().unwrap())
|
||||
.commit(
|
||||
TlsCommitConfig::builder()
|
||||
.protocol(
|
||||
MpcTlsConfig::builder()
|
||||
.max_sent_data(MAX_SENT_DATA)
|
||||
.max_recv_data(MAX_RECV_DATA)
|
||||
.defer_decryption_from_start(true)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.build()
|
||||
.unwrap(),
|
||||
&mut verifier_io,
|
||||
)
|
||||
let tls_config = tls_config_builder.build().unwrap();
|
||||
|
||||
let server_name = ServerName::Dns(SERVER_DOMAIN.try_into().unwrap());
|
||||
let prover = Prover::new(
|
||||
ProverConfig::builder()
|
||||
.server_name(server_name)
|
||||
.tls_config(tls_config)
|
||||
.protocol_config(
|
||||
ProtocolConfig::builder()
|
||||
.max_sent_data(MAX_SENT_DATA)
|
||||
.max_recv_data(MAX_RECV_DATA)
|
||||
.defer_decryption_from_start(true)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.setup(provider.provide_proto_io().await.unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (tls_connection, prover_fut) = prover
|
||||
.connect(provider.provide_server_io().await.unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_io = provider.provide_server_io().await.unwrap();
|
||||
|
||||
let (tls_connection, prover) = prover
|
||||
.setup(
|
||||
TlsClientConfig::builder()
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let prover_task = spawn(prover.run(server_io, verifier_io));
|
||||
let prover_task = spawn(prover_fut);
|
||||
|
||||
let (mut request_sender, connection) =
|
||||
hyper::client::conn::http1::handshake(FuturesIo::new(tls_connection))
|
||||
@@ -89,7 +80,7 @@ async fn prover(provider: &IoProvider) {
|
||||
|
||||
let _ = response.into_body().collect().await.unwrap().to_bytes();
|
||||
|
||||
let (mut prover, _, mut verifier_io) = prover_task.await.unwrap().unwrap();
|
||||
let mut prover = prover_task.await.unwrap().unwrap();
|
||||
|
||||
let (sent_len, recv_len) = prover.transcript().len();
|
||||
|
||||
@@ -116,48 +107,39 @@ async fn prover(provider: &IoProvider) {
|
||||
|
||||
let config = builder.build().unwrap();
|
||||
|
||||
prover.prove(&config, &mut verifier_io).await.unwrap();
|
||||
prover.close(&mut verifier_io).await.unwrap();
|
||||
prover.prove(&config).await.unwrap();
|
||||
prover.close().await.unwrap();
|
||||
}
|
||||
|
||||
async fn verifier(provider: &IoProvider) {
|
||||
let mut prover_io = provider.provide_proto_io().await.unwrap();
|
||||
|
||||
let config = VerifierConfig::builder()
|
||||
.protocol_config_validator(
|
||||
ProtocolConfigValidator::builder()
|
||||
.max_sent_data(MAX_SENT_DATA)
|
||||
.max_recv_data(MAX_RECV_DATA)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let verifier = Verifier::new(config)
|
||||
.commit(&mut prover_io)
|
||||
.await
|
||||
.unwrap()
|
||||
.accept(&mut prover_io)
|
||||
.await
|
||||
.unwrap()
|
||||
.run(&mut prover_io)
|
||||
let verifier = Verifier::new(config);
|
||||
|
||||
let VerifierOutput {
|
||||
server_name,
|
||||
transcript_commitments,
|
||||
..
|
||||
} = verifier
|
||||
.verify(
|
||||
provider.provide_proto_io().await.unwrap(),
|
||||
&VerifyConfig::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (
|
||||
VerifierOutput {
|
||||
server_name,
|
||||
transcript_commitments,
|
||||
..
|
||||
},
|
||||
verifier,
|
||||
) = verifier
|
||||
.verify(&mut prover_io)
|
||||
.await
|
||||
.unwrap()
|
||||
.accept(&mut prover_io)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
verifier.close(&mut prover_io).await.unwrap();
|
||||
|
||||
let ServerName::Dns(server_name) = server_name.unwrap();
|
||||
|
||||
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
|
||||
|
||||
@@ -22,7 +22,6 @@ clap = { workspace = true, features = ["derive", "env"] }
|
||||
csv = { version = "1.3" }
|
||||
duct = { version = "1" }
|
||||
futures = { workspace = true }
|
||||
indicatif = { version = "0.17" }
|
||||
ipnet = { workspace = true }
|
||||
serio = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
@@ -16,10 +16,6 @@ pub struct Cli {
|
||||
/// Subnet to assign harness network interfaces.
|
||||
#[arg(long, default_value = "10.250.0.0/24", env = "SUBNET")]
|
||||
pub subnet: Ipv4Net,
|
||||
/// Run browser in headed mode (visible window) for debugging.
|
||||
/// Works with both X11 and Wayland.
|
||||
#[arg(long)]
|
||||
pub headed: bool,
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
@@ -35,13 +31,10 @@ pub enum Command {
|
||||
},
|
||||
/// runs benchmarks.
|
||||
Bench {
|
||||
/// Configuration path. Defaults to bench.toml which contains
|
||||
/// representative scenarios (cable, 5G, fiber) for quick performance
|
||||
/// checks. Use bench_*_sweep.toml files for parametric
|
||||
/// analysis.
|
||||
/// Configuration path.
|
||||
#[arg(short, long, default_value = "bench.toml")]
|
||||
config: PathBuf,
|
||||
/// Output CSV file path for detailed metrics and post-processing.
|
||||
/// Output file path.
|
||||
#[arg(short, long, default_value = "metrics.csv")]
|
||||
output: PathBuf,
|
||||
/// Number of samples to measure per benchmark. This is overridden by
|
||||
|
||||
@@ -28,9 +28,6 @@ pub struct Executor {
|
||||
ns: Namespace,
|
||||
config: ExecutorConfig,
|
||||
target: Target,
|
||||
/// Display environment variables for headed mode (X11/Wayland).
|
||||
/// Empty means headless mode.
|
||||
display_env: Vec<String>,
|
||||
state: State,
|
||||
}
|
||||
|
||||
@@ -52,17 +49,11 @@ impl State {
|
||||
}
|
||||
|
||||
impl Executor {
|
||||
pub fn new(
|
||||
ns: Namespace,
|
||||
config: ExecutorConfig,
|
||||
target: Target,
|
||||
display_env: Vec<String>,
|
||||
) -> Self {
|
||||
pub fn new(ns: Namespace, config: ExecutorConfig, target: Target) -> Self {
|
||||
Self {
|
||||
ns,
|
||||
config,
|
||||
target,
|
||||
display_env,
|
||||
state: State::Init,
|
||||
}
|
||||
}
|
||||
@@ -129,49 +120,23 @@ impl Executor {
|
||||
let tmp = duct::cmd!("mktemp", "-d").read()?;
|
||||
let tmp = tmp.trim();
|
||||
|
||||
let headed = !self.display_env.is_empty();
|
||||
|
||||
// Build command args based on headed/headless mode
|
||||
let mut args: Vec<String> = vec![
|
||||
"ip".into(),
|
||||
"netns".into(),
|
||||
"exec".into(),
|
||||
self.ns.name().into(),
|
||||
];
|
||||
|
||||
if headed {
|
||||
// For headed mode: drop back to the current user and pass display env vars
|
||||
// This allows the browser to connect to X11/Wayland while in the namespace
|
||||
let user =
|
||||
std::env::var("USER").context("USER environment variable not set")?;
|
||||
args.extend(["sudo".into(), "-E".into(), "-u".into(), user, "env".into()]);
|
||||
args.extend(self.display_env.clone());
|
||||
}
|
||||
|
||||
args.push(chrome_path.to_string_lossy().into());
|
||||
args.push(format!("--remote-debugging-port={PORT_BROWSER}"));
|
||||
|
||||
if headed {
|
||||
// Headed mode: no headless, add flags to suppress first-run dialogs
|
||||
args.extend(["--no-first-run".into(), "--no-default-browser-check".into()]);
|
||||
} else {
|
||||
// Headless mode: original flags
|
||||
args.extend([
|
||||
"--headless".into(),
|
||||
"--disable-dev-shm-usage".into(),
|
||||
"--disable-gpu".into(),
|
||||
"--disable-cache".into(),
|
||||
"--disable-application-cache".into(),
|
||||
]);
|
||||
}
|
||||
|
||||
args.extend([
|
||||
"--no-sandbox".into(),
|
||||
let process = duct::cmd!(
|
||||
"sudo",
|
||||
"ip",
|
||||
"netns",
|
||||
"exec",
|
||||
self.ns.name(),
|
||||
chrome_path,
|
||||
format!("--remote-debugging-port={PORT_BROWSER}"),
|
||||
"--headless",
|
||||
"--disable-dev-shm-usage",
|
||||
"--disable-gpu",
|
||||
"--disable-cache",
|
||||
"--disable-application-cache",
|
||||
"--no-sandbox",
|
||||
format!("--user-data-dir={tmp}"),
|
||||
"--allowed-ips=10.250.0.1".into(),
|
||||
]);
|
||||
|
||||
let process = duct::cmd("sudo", &args);
|
||||
format!("--allowed-ips=10.250.0.1"),
|
||||
);
|
||||
|
||||
let process = if !cfg!(feature = "debug") {
|
||||
process.stderr_capture().stdout_capture().start()?
|
||||
|
||||
@@ -9,7 +9,7 @@ mod ws_proxy;
|
||||
#[cfg(feature = "debug")]
|
||||
mod debug_prelude;
|
||||
|
||||
use std::{collections::HashMap, time::Duration};
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
@@ -22,7 +22,6 @@ use harness_core::{
|
||||
rpc::{BenchCmd, TestCmd},
|
||||
test::TestStatus,
|
||||
};
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
|
||||
use cli::{Cli, Command};
|
||||
use executor::Executor;
|
||||
@@ -33,67 +32,18 @@ use crate::debug_prelude::*;
|
||||
|
||||
use crate::{cli::Route, network::Network, wasm_server::WasmServer, ws_proxy::WsProxy};
|
||||
|
||||
/// Statistics for a benchmark configuration
|
||||
#[derive(Debug, Clone)]
|
||||
struct BenchStats {
|
||||
group: Option<String>,
|
||||
bandwidth: usize,
|
||||
latency: usize,
|
||||
upload_size: usize,
|
||||
download_size: usize,
|
||||
times: Vec<u64>,
|
||||
}
|
||||
|
||||
impl BenchStats {
|
||||
fn median(&self) -> f64 {
|
||||
let mut sorted = self.times.clone();
|
||||
sorted.sort();
|
||||
let len = sorted.len();
|
||||
if len == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
if len.is_multiple_of(2) {
|
||||
(sorted[len / 2 - 1] + sorted[len / 2]) as f64 / 2.0
|
||||
} else {
|
||||
sorted[len / 2] as f64
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Print summary table of benchmark results
|
||||
fn print_bench_summary(stats: &[BenchStats]) {
|
||||
if stats.is_empty() {
|
||||
println!("\nNo benchmark results to display (only warmup was run).");
|
||||
return;
|
||||
}
|
||||
|
||||
println!("\n{}", "=".repeat(80));
|
||||
println!("TLSNotary Benchmark Results");
|
||||
println!("{}", "=".repeat(80));
|
||||
println!();
|
||||
|
||||
for stat in stats {
|
||||
let group_name = stat.group.as_deref().unwrap_or("unnamed");
|
||||
println!(
|
||||
"{} ({} Mbps, {}ms latency, {}KB↑ {}KB↓):",
|
||||
group_name,
|
||||
stat.bandwidth,
|
||||
stat.latency,
|
||||
stat.upload_size / 1024,
|
||||
stat.download_size / 1024
|
||||
);
|
||||
println!(" Median: {:.2}s", stat.median() / 1000.0);
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum, Default)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
pub enum Target {
|
||||
#[default]
|
||||
Native,
|
||||
Browser,
|
||||
}
|
||||
|
||||
impl Default for Target {
|
||||
fn default() -> Self {
|
||||
Self::Native
|
||||
}
|
||||
}
|
||||
|
||||
struct Runner {
|
||||
network: Network,
|
||||
server_fixture: ServerFixture,
|
||||
@@ -105,46 +55,14 @@ struct Runner {
|
||||
started: bool,
|
||||
}
|
||||
|
||||
/// Collects display-related environment variables for headed browser mode.
|
||||
/// Works with both X11 and Wayland by collecting whichever vars are present.
|
||||
fn collect_display_env_vars() -> Vec<String> {
|
||||
const DISPLAY_VARS: &[&str] = &[
|
||||
"DISPLAY", // X11
|
||||
"XAUTHORITY", // X11 auth
|
||||
"WAYLAND_DISPLAY", // Wayland
|
||||
"XDG_RUNTIME_DIR", // Wayland runtime dir
|
||||
];
|
||||
|
||||
DISPLAY_VARS
|
||||
.iter()
|
||||
.filter_map(|&var| {
|
||||
std::env::var(var)
|
||||
.ok()
|
||||
.map(|val| format!("{}={}", var, val))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
impl Runner {
|
||||
fn new(cli: &Cli) -> Result<Self> {
|
||||
let Cli {
|
||||
target,
|
||||
subnet,
|
||||
headed,
|
||||
..
|
||||
} = cli;
|
||||
let Cli { target, subnet, .. } = cli;
|
||||
let current_path = std::env::current_exe().unwrap();
|
||||
let fixture_path = current_path.parent().unwrap().join("server-fixture");
|
||||
let network_config = NetworkConfig::new(*subnet);
|
||||
let network = Network::new(network_config.clone())?;
|
||||
|
||||
// Collect display env vars once if headed mode is enabled
|
||||
let display_env = if *headed {
|
||||
collect_display_env_vars()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
let server_fixture =
|
||||
ServerFixture::new(fixture_path, network.ns_app().clone(), network_config.app);
|
||||
let wasm_server = WasmServer::new(
|
||||
@@ -162,7 +80,6 @@ impl Runner {
|
||||
.network_config(network_config.clone())
|
||||
.build(),
|
||||
*target,
|
||||
display_env.clone(),
|
||||
);
|
||||
let exec_v = Executor::new(
|
||||
network.ns_1().clone(),
|
||||
@@ -172,7 +89,6 @@ impl Runner {
|
||||
.network_config(network_config.clone())
|
||||
.build(),
|
||||
Target::Native,
|
||||
Vec::new(), // Verifier doesn't need display env
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
@@ -207,12 +123,6 @@ pub async fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let cli = Cli::parse();
|
||||
|
||||
// Validate --headed requires --target browser
|
||||
if cli.headed && cli.target != Target::Browser {
|
||||
anyhow::bail!("--headed can only be used with --target browser");
|
||||
}
|
||||
|
||||
let mut runner = Runner::new(&cli)?;
|
||||
|
||||
let mut exit_code = 0;
|
||||
@@ -301,12 +211,6 @@ pub async fn main() -> Result<()> {
|
||||
samples_override,
|
||||
skip_warmup,
|
||||
} => {
|
||||
// Print configuration info
|
||||
println!("TLSNotary Benchmark Harness");
|
||||
println!("Running benchmarks from: {}", config.display());
|
||||
println!("Output will be written to: {}", output.display());
|
||||
println!();
|
||||
|
||||
let items: BenchItems = toml::from_str(&std::fs::read_to_string(config)?)?;
|
||||
let output_file = std::fs::File::create(output)?;
|
||||
let mut writer = WriterBuilder::new().from_writer(output_file);
|
||||
@@ -321,34 +225,7 @@ pub async fn main() -> Result<()> {
|
||||
runner.exec_p.start().await?;
|
||||
runner.exec_v.start().await?;
|
||||
|
||||
// Create progress bar
|
||||
let pb = ProgressBar::new(benches.len() as u64);
|
||||
pb.set_style(
|
||||
ProgressStyle::default_bar()
|
||||
.template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
|
||||
.expect("valid template")
|
||||
.progress_chars("█▓▒░ "),
|
||||
);
|
||||
|
||||
// Collect measurements for stats
|
||||
let mut measurements_by_config: HashMap<String, Vec<u64>> = HashMap::new();
|
||||
|
||||
let warmup_count = if skip_warmup { 0 } else { 3 };
|
||||
|
||||
for (idx, config) in benches.iter().enumerate() {
|
||||
let is_warmup = idx < warmup_count;
|
||||
|
||||
let group_name = if is_warmup {
|
||||
format!("Warmup {}/{}", idx + 1, warmup_count)
|
||||
} else {
|
||||
config.group.as_deref().unwrap_or("unnamed").to_string()
|
||||
};
|
||||
|
||||
pb.set_message(format!(
|
||||
"{} ({} Mbps, {}ms)",
|
||||
group_name, config.bandwidth, config.protocol_latency
|
||||
));
|
||||
|
||||
for config in benches {
|
||||
runner
|
||||
.network
|
||||
.set_proto_config(config.bandwidth, config.protocol_latency.div_ceil(2))?;
|
||||
@@ -377,73 +254,11 @@ pub async fn main() -> Result<()> {
|
||||
panic!("expected prover output");
|
||||
};
|
||||
|
||||
// Collect metrics for stats (skip warmup benches)
|
||||
if !is_warmup {
|
||||
let config_key = format!(
|
||||
"{:?}|{}|{}|{}|{}",
|
||||
config.group,
|
||||
config.bandwidth,
|
||||
config.protocol_latency,
|
||||
config.upload_size,
|
||||
config.download_size
|
||||
);
|
||||
measurements_by_config
|
||||
.entry(config_key)
|
||||
.or_default()
|
||||
.push(metrics.time_total);
|
||||
}
|
||||
|
||||
let measurement = Measurement::new(config.clone(), metrics);
|
||||
let measurement = Measurement::new(config, metrics);
|
||||
|
||||
writer.serialize(measurement)?;
|
||||
writer.flush()?;
|
||||
|
||||
pb.inc(1);
|
||||
}
|
||||
|
||||
pb.finish_with_message("Benchmarks complete");
|
||||
|
||||
// Compute and print statistics
|
||||
let mut all_stats: Vec<BenchStats> = Vec::new();
|
||||
for (key, times) in measurements_by_config {
|
||||
// Parse back the config from the key
|
||||
let parts: Vec<&str> = key.split('|').collect();
|
||||
if parts.len() >= 5 {
|
||||
let group = if parts[0] == "None" {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
parts[0]
|
||||
.trim_start_matches("Some(\"")
|
||||
.trim_end_matches("\")")
|
||||
.to_string(),
|
||||
)
|
||||
};
|
||||
let bandwidth: usize = parts[1].parse().unwrap_or(0);
|
||||
let latency: usize = parts[2].parse().unwrap_or(0);
|
||||
let upload_size: usize = parts[3].parse().unwrap_or(0);
|
||||
let download_size: usize = parts[4].parse().unwrap_or(0);
|
||||
|
||||
all_stats.push(BenchStats {
|
||||
group,
|
||||
bandwidth,
|
||||
latency,
|
||||
upload_size,
|
||||
download_size,
|
||||
times,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Sort stats by group name for consistent output
|
||||
all_stats.sort_by(|a, b| {
|
||||
a.group
|
||||
.cmp(&b.group)
|
||||
.then(a.latency.cmp(&b.latency))
|
||||
.then(a.bandwidth.cmp(&b.bandwidth))
|
||||
});
|
||||
|
||||
print_bench_summary(&all_stats);
|
||||
}
|
||||
Command::Serve {} => {
|
||||
runner.start_services().await?;
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "TLSNotary MPC-TLS protocol"
|
||||
keywords = ["tls", "mpc", "2pc"]
|
||||
categories = ["cryptography"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
version = "0.1.0-alpha.13"
|
||||
edition = "2021"
|
||||
|
||||
[lints]
|
||||
@@ -33,6 +33,7 @@ mpz-ole = { workspace = true }
|
||||
mpz-share-conversion = { workspace = true }
|
||||
mpz-vm-core = { workspace = true }
|
||||
mpz-memory-core = { workspace = true }
|
||||
mpz-circuits = { workspace = true }
|
||||
|
||||
ludi = { git = "https://github.com/sinui0/ludi", rev = "e511c3b", default-features = false }
|
||||
serio = { workspace = true }
|
||||
@@ -56,9 +57,9 @@ pin-project-lite = { workspace = true }
|
||||
web-time = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
mpz-common = { workspace = true, features = ["test-utils"] }
|
||||
mpz-ot = { workspace = true, features = ["ideal"] }
|
||||
mpz-ideal-vm = { workspace = true }
|
||||
mpz-ole = { workspace = true, features = ["test-utils"] }
|
||||
mpz-ot = { workspace = true }
|
||||
mpz-garble = { workspace = true }
|
||||
|
||||
cipher-crate = { package = "cipher", version = "0.4" }
|
||||
generic-array = { workspace = true }
|
||||
@@ -66,6 +67,7 @@ rand_chacha = { workspace = true }
|
||||
rstest = { workspace = true }
|
||||
tls-server-fixture = { workspace = true }
|
||||
tlsn-tls-client = { workspace = true }
|
||||
tlsn-tls-client-async = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
|
||||
tokio-util = { workspace = true, features = ["compat"] }
|
||||
tracing-subscriber = { workspace = true }
|
||||
|
||||
@@ -378,29 +378,15 @@ impl MpcTlsLeader {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Enables or disables the decryption of any incoming messages.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `enable` - Whether to enable or disable decryption.
|
||||
/// Defers decryption of any incoming messages.
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
pub fn enable_decryption(&mut self, enable: bool) -> Result<(), MpcTlsError> {
|
||||
self.is_decrypting = enable;
|
||||
|
||||
if enable {
|
||||
self.notifier.set();
|
||||
} else {
|
||||
self.notifier.clear();
|
||||
}
|
||||
pub async fn defer_decryption(&mut self) -> Result<(), MpcTlsError> {
|
||||
self.is_decrypting = false;
|
||||
self.notifier.clear();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns if incoming messages are decrypted.
|
||||
pub fn is_decrypting(&self) -> bool {
|
||||
self.is_decrypting
|
||||
}
|
||||
|
||||
/// Stops the actor.
|
||||
pub fn stop(&mut self, ctx: &mut LudiContext<Self>) {
|
||||
ctx.stop();
|
||||
|
||||
@@ -32,14 +32,10 @@ impl MpcTlsLeaderCtrl {
|
||||
Self { address }
|
||||
}
|
||||
|
||||
/// Enables or disables the decryption of any incoming messages.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `enable` - Whether to enable or disable decryption.
|
||||
pub async fn enable_decryption(&self, enable: bool) -> Result<(), MpcTlsError> {
|
||||
/// Defers decryption of any incoming messages.
|
||||
pub async fn defer_decryption(&self) -> Result<(), MpcTlsError> {
|
||||
self.address
|
||||
.send(EnableDecryption { enable })
|
||||
.send(DeferDecryption)
|
||||
.await
|
||||
.map_err(MpcTlsError::actor)?
|
||||
}
|
||||
@@ -985,7 +981,7 @@ impl Handler<BackendMsgServerClosed> for MpcTlsLeader {
|
||||
}
|
||||
}
|
||||
|
||||
impl Dispatch<MpcTlsLeader> for EnableDecryption {
|
||||
impl Dispatch<MpcTlsLeader> for DeferDecryption {
|
||||
fn dispatch<R: FnOnce(Self::Return) + Send>(
|
||||
self,
|
||||
actor: &mut MpcTlsLeader,
|
||||
@@ -996,13 +992,13 @@ impl Dispatch<MpcTlsLeader> for EnableDecryption {
|
||||
}
|
||||
}
|
||||
|
||||
impl Handler<EnableDecryption> for MpcTlsLeader {
|
||||
impl Handler<DeferDecryption> for MpcTlsLeader {
|
||||
async fn handle(
|
||||
&mut self,
|
||||
msg: EnableDecryption,
|
||||
_msg: DeferDecryption,
|
||||
_ctx: &mut LudiCtx<Self>,
|
||||
) -> <EnableDecryption as Message>::Return {
|
||||
self.enable_decryption(msg.enable)
|
||||
) -> <DeferDecryption as Message>::Return {
|
||||
self.defer_decryption().await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1052,7 +1048,7 @@ pub enum MpcTlsLeaderMsg {
|
||||
BackendMsgGetNotify(BackendMsgGetNotify),
|
||||
BackendMsgIsEmpty(BackendMsgIsEmpty),
|
||||
BackendMsgServerClosed(BackendMsgServerClosed),
|
||||
DeferDecryption(EnableDecryption),
|
||||
DeferDecryption(DeferDecryption),
|
||||
Stop(Stop),
|
||||
}
|
||||
|
||||
@@ -1087,7 +1083,7 @@ pub enum MpcTlsLeaderMsgReturn {
|
||||
BackendMsgGetNotify(<BackendMsgGetNotify as Message>::Return),
|
||||
BackendMsgIsEmpty(<BackendMsgIsEmpty as Message>::Return),
|
||||
BackendMsgServerClosed(<BackendMsgServerClosed as Message>::Return),
|
||||
DeferDecryption(<EnableDecryption as Message>::Return),
|
||||
DeferDecryption(<DeferDecryption as Message>::Return),
|
||||
Stop(<Stop as Message>::Return),
|
||||
}
|
||||
|
||||
@@ -1736,25 +1732,23 @@ impl Wrap<BackendMsgServerClosed> for MpcTlsLeaderMsg {
|
||||
}
|
||||
}
|
||||
|
||||
/// Message to enable or disable the decryption of messages.
|
||||
/// Message to start deferring the decryption.
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug)]
|
||||
pub struct EnableDecryption {
|
||||
pub enable: bool,
|
||||
}
|
||||
pub struct DeferDecryption;
|
||||
|
||||
impl Message for EnableDecryption {
|
||||
impl Message for DeferDecryption {
|
||||
type Return = Result<(), MpcTlsError>;
|
||||
}
|
||||
|
||||
impl From<EnableDecryption> for MpcTlsLeaderMsg {
|
||||
fn from(value: EnableDecryption) -> Self {
|
||||
impl From<DeferDecryption> for MpcTlsLeaderMsg {
|
||||
fn from(value: DeferDecryption) -> Self {
|
||||
MpcTlsLeaderMsg::DeferDecryption(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Wrap<EnableDecryption> for MpcTlsLeaderMsg {
|
||||
fn unwrap_return(ret: Self::Return) -> Result<<EnableDecryption as Message>::Return, Error> {
|
||||
impl Wrap<DeferDecryption> for MpcTlsLeaderMsg {
|
||||
fn unwrap_return(ret: Self::Return) -> Result<<DeferDecryption as Message>::Return, Error> {
|
||||
match ret {
|
||||
Self::Return::DeferDecryption(value) => Ok(value),
|
||||
_ => Err(Error::Wrapper),
|
||||
|
||||
@@ -487,7 +487,7 @@ impl RecordLayer {
|
||||
|
||||
sent_records.push(Record {
|
||||
seq: op.seq,
|
||||
typ: op.typ.into(),
|
||||
typ: op.typ,
|
||||
plaintext: op.plaintext,
|
||||
explicit_nonce: op.explicit_nonce,
|
||||
ciphertext,
|
||||
@@ -505,7 +505,7 @@ impl RecordLayer {
|
||||
|
||||
recv_records.push(Record {
|
||||
seq: op.seq,
|
||||
typ: op.typ.into(),
|
||||
typ: op.typ,
|
||||
plaintext,
|
||||
explicit_nonce: op.explicit_nonce,
|
||||
ciphertext: op.ciphertext,
|
||||
@@ -578,7 +578,7 @@ impl RecordLayer {
|
||||
|
||||
recv_records.push(Record {
|
||||
seq: op.seq,
|
||||
typ: op.typ.into(),
|
||||
typ: op.typ,
|
||||
plaintext,
|
||||
explicit_nonce: op.explicit_nonce,
|
||||
ciphertext: op.ciphertext,
|
||||
|
||||
@@ -456,8 +456,9 @@ mod tests {
|
||||
};
|
||||
use mpz_common::context::test_st_context;
|
||||
use mpz_core::Block;
|
||||
use mpz_ideal_vm::IdealVm;
|
||||
use mpz_memory_core::binary::U8;
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
|
||||
use mpz_memory_core::{binary::U8, correlated::Delta};
|
||||
use mpz_ot::ideal::cot::ideal_cot;
|
||||
use mpz_share_conversion::ideal::ideal_share_convert;
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
use rstest::*;
|
||||
@@ -573,8 +574,13 @@ mod tests {
|
||||
}
|
||||
|
||||
fn create_vm(key: [u8; 16], iv: [u8; 4]) -> ((impl Vm<Binary>, Vars), (impl Vm<Binary>, Vars)) {
|
||||
let mut vm_0 = IdealVm::new();
|
||||
let mut vm_1 = IdealVm::new();
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let block = Block::random(&mut rng);
|
||||
let (sender, receiver) = ideal_cot(block);
|
||||
|
||||
let delta = Delta::new(block);
|
||||
let mut vm_0 = Garbler::new(sender, [0u8; 16], delta);
|
||||
let mut vm_1 = Evaluator::new(receiver);
|
||||
|
||||
let key_ref_0 = vm_0.alloc::<Array<U8, 16>>().unwrap();
|
||||
vm_0.mark_public(key_ref_0).unwrap();
|
||||
|
||||
168
crates/mpc-tls/tests/test.rs
Normal file
168
crates/mpc-tls/tests/test.rs
Normal file
@@ -0,0 +1,168 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
||||
use mpc_tls::{Config, MpcTlsFollower, MpcTlsLeader};
|
||||
use mpz_common::context::test_mt_context;
|
||||
use mpz_core::Block;
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
|
||||
use mpz_memory_core::correlated::Delta;
|
||||
use mpz_ot::{
|
||||
cot::{DerandCOTReceiver, DerandCOTSender},
|
||||
ideal::rcot::ideal_rcot,
|
||||
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
|
||||
};
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
use rustls_pki_types::CertificateDer;
|
||||
use tls_client::RootCertStore;
|
||||
use tls_client_async::bind_client;
|
||||
use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_util::compat::TokioAsyncReadCompatExt;
|
||||
use webpki::anchor_from_trusted_cert;
|
||||
|
||||
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[ignore = "expensive"]
|
||||
async fn mpc_tls_test() {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let config = Config::builder()
|
||||
.defer_decryption(false)
|
||||
.max_sent(1 << 13)
|
||||
.max_recv_online(1 << 13)
|
||||
.max_recv(1 << 13)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let (leader, follower) = build_pair(config);
|
||||
|
||||
tokio::try_join!(
|
||||
tokio::spawn(leader_task(leader)),
|
||||
tokio::spawn(follower_task(follower))
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn leader_task(mut leader: MpcTlsLeader) {
|
||||
leader.alloc().unwrap();
|
||||
|
||||
leader.preprocess().await.unwrap();
|
||||
|
||||
let (leader_ctrl, leader_fut) = leader.run();
|
||||
tokio::spawn(async { leader_fut.await.unwrap() });
|
||||
|
||||
let config = tls_client::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(RootCertStore {
|
||||
roots: vec![anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned()],
|
||||
})
|
||||
.with_no_client_auth();
|
||||
|
||||
let server_name = SERVER_DOMAIN.try_into().unwrap();
|
||||
|
||||
let client = tls_client::ClientConnection::new(
|
||||
Arc::new(config),
|
||||
Box::new(leader_ctrl.clone()),
|
||||
server_name,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
|
||||
tokio::spawn(bind_test_server_hyper(server_socket.compat()));
|
||||
|
||||
let (mut conn, conn_fut) = bind_client(client_socket.compat(), client);
|
||||
let handle = tokio::spawn(async { conn_fut.await.unwrap() });
|
||||
|
||||
let msg = concat!(
|
||||
"POST /echo HTTP/1.1\r\n",
|
||||
"Host: test-server.io\r\n",
|
||||
"Connection: keep-alive\r\n",
|
||||
"Accept-Encoding: identity\r\n",
|
||||
"Content-Length: 5\r\n",
|
||||
"\r\n",
|
||||
"hello",
|
||||
"\r\n"
|
||||
);
|
||||
|
||||
conn.write_all(msg.as_bytes()).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 48];
|
||||
conn.read_exact(&mut buf).await.unwrap();
|
||||
|
||||
leader_ctrl.defer_decryption().await.unwrap();
|
||||
|
||||
let msg = concat!(
|
||||
"POST /echo HTTP/1.1\r\n",
|
||||
"Host: test-server.io\r\n",
|
||||
"Connection: close\r\n",
|
||||
"Accept-Encoding: identity\r\n",
|
||||
"Content-Length: 5\r\n",
|
||||
"\r\n",
|
||||
"hello",
|
||||
"\r\n"
|
||||
);
|
||||
|
||||
conn.write_all(msg.as_bytes()).await.unwrap();
|
||||
conn.close().await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
conn.read_to_end(&mut buf).await.unwrap();
|
||||
|
||||
leader_ctrl.stop().await.unwrap();
|
||||
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
async fn follower_task(mut follower: MpcTlsFollower) {
|
||||
follower.alloc().unwrap();
|
||||
follower.preprocess().await.unwrap();
|
||||
follower.run().await.unwrap();
|
||||
}
|
||||
|
||||
fn build_pair(config: Config) -> (MpcTlsLeader, MpcTlsFollower) {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
|
||||
let (mut mt_a, mut mt_b) = test_mt_context(8);
|
||||
|
||||
let ctx_a = futures::executor::block_on(mt_a.new_context()).unwrap();
|
||||
let ctx_b = futures::executor::block_on(mt_b.new_context()).unwrap();
|
||||
|
||||
let delta_a = Delta::new(Block::random(&mut rng));
|
||||
let delta_b = Delta::new(Block::random(&mut rng));
|
||||
|
||||
let (rcot_send_a, rcot_recv_b) = ideal_rcot(Block::random(&mut rng), delta_a.into_inner());
|
||||
let (rcot_send_b, rcot_recv_a) = ideal_rcot(Block::random(&mut rng), delta_b.into_inner());
|
||||
|
||||
let rcot_send_a = SharedRCOTSender::new(rcot_send_a);
|
||||
let rcot_send_b = SharedRCOTSender::new(rcot_send_b);
|
||||
let rcot_recv_a = SharedRCOTReceiver::new(rcot_recv_a);
|
||||
let rcot_recv_b = SharedRCOTReceiver::new(rcot_recv_b);
|
||||
|
||||
let mpc_a = Arc::new(Mutex::new(Garbler::new(
|
||||
DerandCOTSender::new(rcot_send_a.clone()),
|
||||
rand::rng().random(),
|
||||
delta_a,
|
||||
)));
|
||||
let mpc_b = Arc::new(Mutex::new(Evaluator::new(DerandCOTReceiver::new(
|
||||
rcot_recv_b.clone(),
|
||||
))));
|
||||
|
||||
let leader = MpcTlsLeader::new(
|
||||
config.clone(),
|
||||
ctx_a,
|
||||
mpc_a,
|
||||
(rcot_send_a.clone(), rcot_send_a.clone(), rcot_send_a),
|
||||
rcot_recv_a,
|
||||
);
|
||||
|
||||
let follower = MpcTlsFollower::new(
|
||||
config,
|
||||
ctx_b,
|
||||
mpc_b,
|
||||
rcot_send_b,
|
||||
(rcot_recv_b.clone(), rcot_recv_b.clone(), rcot_recv_b),
|
||||
);
|
||||
|
||||
(leader, follower)
|
||||
}
|
||||
@@ -5,7 +5,7 @@ description = "A TLS backend trait for TLSNotary"
|
||||
keywords = ["tls", "mpc", "2pc"]
|
||||
categories = ["cryptography"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
version = "0.1.0-alpha.13"
|
||||
edition = "2021"
|
||||
|
||||
[lints]
|
||||
|
||||
39
crates/tls/client-async/Cargo.toml
Normal file
39
crates/tls/client-async/Cargo.toml
Normal file
@@ -0,0 +1,39 @@
|
||||
[package]
|
||||
name = "tlsn-tls-client-async"
|
||||
authors = ["TLSNotary Team"]
|
||||
description = "An async TLS client for TLSNotary"
|
||||
keywords = ["tls", "mpc", "2pc", "client", "async"]
|
||||
categories = ["cryptography"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version = "0.1.0-alpha.13"
|
||||
edition = "2021"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
name = "tls_client_async"
|
||||
|
||||
[features]
|
||||
default = ["tracing"]
|
||||
tracing = ["dep:tracing"]
|
||||
|
||||
[dependencies]
|
||||
tlsn-tls-client = { workspace = true }
|
||||
|
||||
bytes = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio-util = { workspace = true, features = ["io", "compat"] }
|
||||
tracing = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tls-server-fixture = { workspace = true }
|
||||
|
||||
http-body-util = { workspace = true }
|
||||
hyper = { workspace = true, features = ["client", "http1"] }
|
||||
hyper-util = { workspace = true, features = ["full"] }
|
||||
rstest = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] }
|
||||
rustls-webpki = { workspace = true }
|
||||
rustls-pki-types = { workspace = true }
|
||||
89
crates/tls/client-async/src/conn.rs
Normal file
89
crates/tls/client-async/src/conn.rs
Normal file
@@ -0,0 +1,89 @@
|
||||
use bytes::Bytes;
|
||||
use futures::{
|
||||
channel::mpsc::{Receiver, SendError, Sender},
|
||||
sink::SinkMapErr,
|
||||
AsyncRead, AsyncWrite, SinkExt,
|
||||
};
|
||||
use std::{
|
||||
io::{Error as IoError, ErrorKind as IoErrorKind},
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio_util::{
|
||||
compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt},
|
||||
io::{CopyToBytes, SinkWriter, StreamReader},
|
||||
};
|
||||
|
||||
type CompatSinkWriter =
|
||||
Compat<SinkWriter<CopyToBytes<SinkMapErr<Sender<Bytes>, fn(SendError) -> IoError>>>>;
|
||||
|
||||
/// A TLS connection to a server.
|
||||
///
|
||||
/// This type implements `AsyncRead` and `AsyncWrite` and can be used to
|
||||
/// communicate with a server using TLS.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// This connection is closed on a best-effort basis if this is dropped. To
|
||||
/// ensure a clean close, you should call
|
||||
/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the
|
||||
/// connection.
|
||||
#[derive(Debug)]
|
||||
pub struct TlsConnection {
|
||||
/// The data to be transmitted to the server is sent to this sink.
|
||||
tx_sender: CompatSinkWriter,
|
||||
/// The data to be received from the server is received from this stream.
|
||||
rx_receiver: Compat<StreamReader<Receiver<Result<Bytes, IoError>>, Bytes>>,
|
||||
}
|
||||
|
||||
impl TlsConnection {
|
||||
/// Creates a new TLS connection.
|
||||
pub(crate) fn new(
|
||||
tx_sender: Sender<Bytes>,
|
||||
rx_receiver: Receiver<Result<Bytes, IoError>>,
|
||||
) -> Self {
|
||||
fn convert_error(err: SendError) -> IoError {
|
||||
if err.is_disconnected() {
|
||||
IoErrorKind::BrokenPipe.into()
|
||||
} else {
|
||||
IoErrorKind::WouldBlock.into()
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
tx_sender: SinkWriter::new(CopyToBytes::new(
|
||||
tx_sender.sink_map_err(convert_error as fn(SendError) -> IoError),
|
||||
))
|
||||
.compat_write(),
|
||||
rx_receiver: StreamReader::new(rx_receiver).compat(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for TlsConnection {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<Result<usize, IoError>> {
|
||||
Pin::new(&mut self.rx_receiver).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TlsConnection {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, IoError>> {
|
||||
Pin::new(&mut self.tx_sender).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
|
||||
Pin::new(&mut self.tx_sender).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
|
||||
Pin::new(&mut self.tx_sender).poll_close(cx)
|
||||
}
|
||||
}
|
||||
269
crates/tls/client-async/src/lib.rs
Normal file
269
crates/tls/client-async/src/lib.rs
Normal file
@@ -0,0 +1,269 @@
|
||||
//! Provides a TLS client which exposes an async socket.
|
||||
//!
|
||||
//! This library provides the [bind_client] function which attaches a TLS client
|
||||
//! to a socket connection and then exposes a [TlsConnection] object, which
|
||||
//! provides an async socket API for reading and writing cleartext. The TLS
|
||||
//! client will then automatically encrypt and decrypt traffic and forward that
|
||||
//! to the provided socket.
|
||||
|
||||
#![deny(missing_docs, unreachable_pub, unused_must_use)]
|
||||
#![deny(clippy::all)]
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
mod conn;
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
use futures::{
|
||||
channel::mpsc, future::Fuse, select_biased, stream::Next, AsyncRead, AsyncReadExt, AsyncWrite,
|
||||
AsyncWriteExt, Future, FutureExt, SinkExt, StreamExt,
|
||||
};
|
||||
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
use tracing::{debug, debug_span, error, trace, warn, Instrument};
|
||||
|
||||
use tls_client::ClientConnection;
|
||||
|
||||
pub use conn::TlsConnection;
|
||||
|
||||
const RX_TLS_BUF_SIZE: usize = 1 << 13; // 8 KiB
|
||||
const RX_BUF_SIZE: usize = 1 << 13; // 8 KiB
|
||||
|
||||
/// An error that can occur during a TLS connection.
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConnectionError {
|
||||
#[error(transparent)]
|
||||
TlsError(#[from] tls_client::Error),
|
||||
#[error(transparent)]
|
||||
IOError(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
/// Closed connection data.
|
||||
#[derive(Debug)]
|
||||
pub struct ClosedConnection {
|
||||
/// The connection for the client
|
||||
pub client: ClientConnection,
|
||||
/// Sent plaintext bytes
|
||||
pub sent: Vec<u8>,
|
||||
/// Received plaintext bytes
|
||||
pub recv: Vec<u8>,
|
||||
}
|
||||
|
||||
/// A future which runs the TLS connection to completion.
|
||||
///
|
||||
/// This future must be polled in order for the connection to make progress.
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct ConnectionFuture {
|
||||
fut: Pin<Box<dyn Future<Output = Result<ClosedConnection, ConnectionError>> + Send>>,
|
||||
}
|
||||
|
||||
impl Future for ConnectionFuture {
|
||||
type Output = Result<ClosedConnection, ConnectionError>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.fut.poll_unpin(cx)
|
||||
}
|
||||
}
|
||||
|
||||
/// Binds a client connection to the provided socket.
|
||||
///
|
||||
/// Returns a connection handle and a future which runs the connection to
|
||||
/// completion.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Any connection errors that occur will be returned from the future, not
|
||||
/// [`TlsConnection`].
|
||||
pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
|
||||
socket: T,
|
||||
mut client: ClientConnection,
|
||||
) -> (TlsConnection, ConnectionFuture) {
|
||||
let (tx_sender, mut tx_receiver) = mpsc::channel(1 << 14);
|
||||
let (mut rx_sender, rx_receiver) = mpsc::channel(1 << 14);
|
||||
|
||||
let conn = TlsConnection::new(tx_sender, rx_receiver);
|
||||
|
||||
let fut = async move {
|
||||
client.start().await?;
|
||||
let mut notify = client.get_notify().await?;
|
||||
|
||||
let (mut server_rx, mut server_tx) = socket.split();
|
||||
|
||||
let mut rx_tls_buf = [0u8; RX_TLS_BUF_SIZE];
|
||||
let mut rx_buf = [0u8; RX_BUF_SIZE];
|
||||
|
||||
let mut handshake_done = false;
|
||||
let mut client_closed = false;
|
||||
let mut server_closed = false;
|
||||
|
||||
let mut sent = Vec::with_capacity(1024);
|
||||
let mut recv = Vec::with_capacity(1024);
|
||||
|
||||
let mut rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse();
|
||||
// We don't start writing application data until the handshake is complete.
|
||||
let mut tx_recv_fut: Fuse<Next<'_, mpsc::Receiver<Bytes>>> = Fuse::terminated();
|
||||
|
||||
// Runs both the tx and rx halves of the connection to completion.
|
||||
// This loop does not terminate until the *SERVER* closes the connection and
|
||||
// we've processed all received data. If an error occurs, the `TlsConnection`
|
||||
// channels will be closed and the error will be returned from this future.
|
||||
'conn: loop {
|
||||
// Write all pending TLS data to the server.
|
||||
if client.wants_write() && !client_closed {
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("client wants to write");
|
||||
while client.wants_write() {
|
||||
let _sent = client.write_tls_async(&mut server_tx).await?;
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("sent {} tls bytes to server", _sent);
|
||||
}
|
||||
server_tx.flush().await?;
|
||||
}
|
||||
|
||||
// Forward received plaintext to `TlsConnection`.
|
||||
while !client.plaintext_is_empty() {
|
||||
let read = client.read_plaintext(&mut rx_buf)?;
|
||||
recv.extend(&rx_buf[..read]);
|
||||
// Ignore if the receiver has hung up.
|
||||
_ = rx_sender
|
||||
.send(Ok(Bytes::copy_from_slice(&rx_buf[..read])))
|
||||
.await;
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("forwarded {} plaintext bytes to conn", read);
|
||||
}
|
||||
|
||||
if !client.is_handshaking() && !handshake_done {
|
||||
#[cfg(feature = "tracing")]
|
||||
debug!("handshake complete");
|
||||
handshake_done = true;
|
||||
// Start reading application data that needs to be transmitted from the
|
||||
// `TlsConnection`.
|
||||
tx_recv_fut = tx_receiver.next().fuse();
|
||||
}
|
||||
|
||||
if server_closed && client.plaintext_is_empty() && client.is_empty().await? {
|
||||
break 'conn;
|
||||
}
|
||||
|
||||
select_biased! {
|
||||
// Reads TLS data from the server and writes it into the client.
|
||||
received = &mut rx_tls_fut => {
|
||||
let received = received?;
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("received {} tls bytes from server", received);
|
||||
|
||||
// Loop until we've processed all the data we received in this read.
|
||||
// Note that we must make one iteration even if `received == 0`.
|
||||
let mut processed = 0;
|
||||
let mut reader = rx_tls_buf[..received].reader();
|
||||
loop {
|
||||
processed += client.read_tls(&mut reader)?;
|
||||
client.process_new_packets().await?;
|
||||
|
||||
debug_assert!(processed <= received);
|
||||
if processed >= received {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("processed {} tls bytes from server", processed);
|
||||
|
||||
// By convention if `AsyncRead::read` returns 0, it means EOF, i.e. the peer
|
||||
// has closed the socket.
|
||||
if received == 0 {
|
||||
#[cfg(feature = "tracing")]
|
||||
debug!("server closed connection");
|
||||
server_closed = true;
|
||||
client.server_closed().await?;
|
||||
// Do not read from the socket again.
|
||||
rx_tls_fut = Fuse::terminated();
|
||||
} else {
|
||||
// Reset the read future so next iteration we can read again.
|
||||
rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse();
|
||||
}
|
||||
}
|
||||
// If we receive None from `TlsConnection`, it has closed, so we
|
||||
// send a close_notify to the server.
|
||||
data = &mut tx_recv_fut => {
|
||||
if let Some(data) = data {
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("writing {} plaintext bytes to client", data.len());
|
||||
|
||||
sent.extend(&data);
|
||||
client
|
||||
.write_all_plaintext(&data)
|
||||
.await?;
|
||||
|
||||
tx_recv_fut = tx_receiver.next().fuse();
|
||||
} else {
|
||||
if !server_closed {
|
||||
if let Err(e) = send_close_notify(&mut client, &mut server_tx).await {
|
||||
#[cfg(feature = "tracing")]
|
||||
warn!("failed to send close_notify to server: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
client_closed = true;
|
||||
|
||||
tx_recv_fut = Fuse::terminated();
|
||||
}
|
||||
}
|
||||
// Waits for a notification from the backend that it is ready to decrypt data.
|
||||
_ = &mut notify => {
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("backend is ready to decrypt");
|
||||
|
||||
client.process_new_packets().await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
debug!("client shutdown");
|
||||
|
||||
_ = server_tx.close().await;
|
||||
tx_receiver.close();
|
||||
rx_sender.close_channel();
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!(
|
||||
"server close notify: {}, sent: {}, recv: {}",
|
||||
client.received_close_notify(),
|
||||
sent.len(),
|
||||
recv.len()
|
||||
);
|
||||
|
||||
Ok(ClosedConnection { client, sent, recv })
|
||||
};
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
let fut = fut.instrument(debug_span!("tls_connection"));
|
||||
|
||||
let fut = ConnectionFuture { fut: Box::pin(fut) };
|
||||
|
||||
(conn, fut)
|
||||
}
|
||||
|
||||
async fn send_close_notify(
|
||||
client: &mut ClientConnection,
|
||||
server_tx: &mut (impl AsyncWrite + Unpin),
|
||||
) -> Result<(), ConnectionError> {
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("sending close_notify to server");
|
||||
client.send_close_notify().await?;
|
||||
client.process_new_packets().await?;
|
||||
|
||||
// Flush all remaining plaintext
|
||||
while client.wants_write() {
|
||||
client.write_tls_async(server_tx).await?;
|
||||
}
|
||||
server_tx.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
438
crates/tls/client-async/tests/test.rs
Normal file
438
crates/tls/client-async/tests/test.rs
Normal file
@@ -0,0 +1,438 @@
|
||||
use std::{str, sync::Arc};
|
||||
|
||||
use core::future::Future;
|
||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
||||
use http_body_util::{BodyExt as _, Full};
|
||||
use hyper::{body::Bytes, Request, StatusCode};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use rstest::{fixture, rstest};
|
||||
use rustls_pki_types::CertificateDer;
|
||||
use tls_client::{ClientConfig, ClientConnection, RustCryptoBackend, ServerName};
|
||||
use tls_client_async::{bind_client, ClosedConnection, ConnectionError, TlsConnection};
|
||||
use tls_server_fixture::{
|
||||
bind_test_server, bind_test_server_hyper, APP_RECORD_LENGTH, CA_CERT_DER, CLOSE_DELAY,
|
||||
SERVER_DOMAIN,
|
||||
};
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
|
||||
use webpki::anchor_from_trusted_cert;
|
||||
|
||||
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
|
||||
|
||||
// An established client TLS connection
|
||||
struct TlsFixture {
|
||||
client_tls_conn: TlsConnection,
|
||||
// a handle that must be `.await`ed to get the result of a TLS connection
|
||||
closed_tls_task: JoinHandle<Result<ClosedConnection, ConnectionError>>,
|
||||
}
|
||||
|
||||
// Sets up a TLS connection between client and server and sends a hello message
|
||||
#[fixture]
|
||||
async fn set_up_tls() -> TlsFixture {
|
||||
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
|
||||
|
||||
let _server_task = tokio::spawn(bind_test_server(server_socket.compat()));
|
||||
|
||||
let mut root_store = tls_client::RootCertStore::empty();
|
||||
root_store
|
||||
.roots
|
||||
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
|
||||
let config = ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
let client = ClientConnection::new(
|
||||
Arc::new(config),
|
||||
Box::new(RustCryptoBackend::new()),
|
||||
ServerName::try_from(SERVER_DOMAIN).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (mut client_tls_conn, tls_fut) = bind_client(client_socket.compat(), client);
|
||||
|
||||
let closed_tls_task = tokio::spawn(tls_fut);
|
||||
|
||||
client_tls_conn
|
||||
.write_all(&pad("expecting you to send back hello".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// give the server some time to respond
|
||||
std::thread::sleep(std::time::Duration::from_millis(10));
|
||||
|
||||
let mut plaintext = vec![0u8; 320];
|
||||
let n = client_tls_conn.read(&mut plaintext).await.unwrap();
|
||||
let s = str::from_utf8(&plaintext[0..n]).unwrap();
|
||||
|
||||
assert_eq!(s, "hello");
|
||||
|
||||
TlsFixture {
|
||||
client_tls_conn,
|
||||
closed_tls_task,
|
||||
}
|
||||
}
|
||||
|
||||
// Expect the async tls client wrapped in `hyper::client` to make a successful
|
||||
// request and receive the expected response
|
||||
#[tokio::test]
|
||||
async fn test_hyper_ok() {
|
||||
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
|
||||
|
||||
let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat()));
|
||||
|
||||
let mut root_store = tls_client::RootCertStore::empty();
|
||||
root_store
|
||||
.roots
|
||||
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
|
||||
let config = ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
let client = ClientConnection::new(
|
||||
Arc::new(config),
|
||||
Box::new(RustCryptoBackend::new()),
|
||||
ServerName::try_from(SERVER_DOMAIN).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (conn, tls_fut) = bind_client(client_socket.compat(), client);
|
||||
|
||||
let closed_tls_task = tokio::spawn(tls_fut);
|
||||
|
||||
let (mut request_sender, connection) =
|
||||
hyper::client::conn::http1::handshake(TokioIo::new(conn.compat()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
tokio::spawn(connection);
|
||||
|
||||
let request = Request::builder()
|
||||
.uri(format!("https://{SERVER_DOMAIN}/echo"))
|
||||
.header("Host", SERVER_DOMAIN)
|
||||
.header("Connection", "close")
|
||||
.method("POST")
|
||||
.body(Full::<Bytes>::new("hello".into()))
|
||||
.unwrap();
|
||||
|
||||
let response = request_sender.send_request(request).await.unwrap();
|
||||
|
||||
assert!(response.status() == StatusCode::OK);
|
||||
|
||||
// Process the response body
|
||||
response.into_body().collect().await.unwrap().to_bytes();
|
||||
|
||||
let _ = server_task.await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect a clean TLS connection closure when server responds to the client's
|
||||
// close_notify but doesn't close the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_server_no_socket_close(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send close_notify back to us after 10 ms
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_close_notify".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// closing `client_tls_conn` will cause close_notify to be sent by the client;
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect a clean TLS connection closure when server responds to the client's
|
||||
// close_notify AND also closes the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_server_socket_close(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send close_notify back to us AND close the socket
|
||||
// after 10 ms
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_close_notify_and_close_socket".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// closing `client_tls_conn` will cause close_notify to be sent by the client;
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect a clean TLS connection closure when server sends close_notify first
|
||||
// but doesn't close the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_server_close_notify(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send close_notify back to us after 10 ms
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_close_notify".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// give enough time for server's close_notify to arrive
|
||||
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
|
||||
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect a clean TLS connection closure when server sends close_notify first
|
||||
// AND also closes the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_server_close_notify_and_socket_close(
|
||||
set_up_tls: impl Future<Output = TlsFixture>,
|
||||
) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send close_notify back to us after 10 ms
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_close_notify_and_close_socket".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// give enough time for server's close_notify to arrive
|
||||
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
|
||||
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect to be able to read the data after server closes the socket abruptly
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_read_after_close(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
..
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send us a hello message
|
||||
client_tls_conn
|
||||
.write_all(&pad("send a hello message".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// instruct the server to close the socket
|
||||
client_tls_conn
|
||||
.write_all(&pad("close_socket".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// give enough time to close the socket
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
|
||||
// try to read some more data
|
||||
let mut buf = vec![0u8; 10];
|
||||
let n = client_tls_conn.read(&mut buf).await.unwrap();
|
||||
|
||||
assert_eq!(std::str::from_utf8(&buf[0..n]).unwrap(), "hello");
|
||||
}
|
||||
|
||||
// Expect there to be no error when server DOES NOT send close_notify but just
|
||||
// closes the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_server_no_close_notify(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to close the socket
|
||||
client_tls_conn
|
||||
.write_all(&pad("close_socket".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// give enough time to close the socket
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(!closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect to register a delay when the server delays closing the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_delay_close(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
client_tls_conn
|
||||
.write_all(&pad("must_delay_when_closing".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// closing `client_tls_conn` will cause close_notify to be sent by the client
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
use std::time::Instant;
|
||||
let now = Instant::now();
|
||||
// this will resolve when the server stops delaying closing the socket
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
let elapsed = now.elapsed();
|
||||
|
||||
// the elapsed time must be roughly equal to the server's delay
|
||||
// (give or take timing variations)
|
||||
assert!(elapsed.as_millis() as u64 > CLOSE_DELAY - 50);
|
||||
|
||||
assert!(!closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect client to error when server sends a corrupted message
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_err_corrupted(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send a corrupted message
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_corrupted_message".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
closed_tls_task.await.unwrap().err().unwrap().to_string(),
|
||||
"received corrupt message"
|
||||
);
|
||||
}
|
||||
|
||||
// Expect client to error when server sends a TLS record with a bad MAC
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_err_bad_mac(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send us a TLS record with a bad MAC
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_record_with_bad_mac".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
closed_tls_task.await.unwrap().err().unwrap().to_string(),
|
||||
"backend error: Decryption error: \"aead::Error\""
|
||||
);
|
||||
}
|
||||
|
||||
// Expect client to error when server sends a fatal alert
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_err_alert(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send us a TLS record with a bad MAC
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_alert".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
closed_tls_task.await.unwrap().err().unwrap().to_string(),
|
||||
"received fatal alert: BadRecordMac"
|
||||
);
|
||||
}
|
||||
|
||||
// Expect an error when trying to write data to a connection which server closed
|
||||
// abruptly
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_err_write_after_close(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
..
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to close the socket
|
||||
client_tls_conn
|
||||
.write_all(&pad("close_socket".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// give enough time to close the socket
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
|
||||
// try to send some more data
|
||||
let res = client_tls_conn
|
||||
.write_all(&pad("more data".to_string()))
|
||||
.await;
|
||||
|
||||
assert_eq!(res.err().unwrap().kind(), std::io::ErrorKind::BrokenPipe);
|
||||
}
|
||||
|
||||
// Converts a string into a slice zero-padded to APP_RECORD_LENGTH
|
||||
fn pad(s: String) -> Vec<u8> {
|
||||
assert!(s.len() <= APP_RECORD_LENGTH);
|
||||
let mut buf = vec![0u8; APP_RECORD_LENGTH];
|
||||
buf[..s.len()].copy_from_slice(s.as_bytes());
|
||||
buf
|
||||
}
|
||||
@@ -5,7 +5,7 @@ description = "A TLS client for TLSNotary"
|
||||
keywords = ["tls", "mpc", "2pc", "client", "sync"]
|
||||
categories = ["cryptography"]
|
||||
license = "Apache-2.0 OR ISC OR MIT"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
version = "0.1.0-alpha.13"
|
||||
edition = "2021"
|
||||
autobenches = false
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use super::{Backend, BackendError};
|
||||
use crate::{DecryptMode, EncryptMode, Error};
|
||||
#[allow(deprecated)]
|
||||
use aes_gcm::{
|
||||
aead::{generic_array::GenericArray, Aead, NewAead, Payload},
|
||||
Aes128Gcm,
|
||||
@@ -508,7 +507,6 @@ impl Encrypter {
|
||||
let mut nonce = [0u8; 12];
|
||||
nonce[..4].copy_from_slice(&self.write_iv);
|
||||
nonce[4..].copy_from_slice(explicit_nonce);
|
||||
#[allow(deprecated)]
|
||||
let nonce = GenericArray::from_slice(&nonce);
|
||||
let cipher = Aes128Gcm::new_from_slice(&self.write_key).unwrap();
|
||||
// ciphertext will have the MAC appended
|
||||
@@ -570,7 +568,6 @@ impl Decrypter {
|
||||
let mut nonce = [0u8; 12];
|
||||
nonce[..4].copy_from_slice(&self.write_iv);
|
||||
nonce[4..].copy_from_slice(&m.payload.0[0..8]);
|
||||
#[allow(deprecated)]
|
||||
let nonce = GenericArray::from_slice(&nonce);
|
||||
let plaintext = cipher
|
||||
.decrypt(nonce, aes_payload)
|
||||
|
||||
@@ -457,9 +457,6 @@ impl ConnectionCommon {
|
||||
return Err(Error::CorruptMessage);
|
||||
}
|
||||
|
||||
// Process outgoing plaintext buffer and encrypt messages.
|
||||
self.flush_plaintext().await?;
|
||||
|
||||
// Process new messages.
|
||||
while let Some(msg) = self.message_deframer.frames.pop_front() {
|
||||
// If we're not decrypting yet, we process it immediately. Otherwise it will be
|
||||
@@ -511,22 +508,25 @@ impl ConnectionCommon {
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
/// Writes plaintext `buf` into an internal buffer. May not fully process the
|
||||
/// whole buffer and returns the processed length.
|
||||
pub fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
|
||||
if buf.is_empty() {
|
||||
// Don't send empty fragments.
|
||||
return Ok(0);
|
||||
/// Write buffer into connection.
|
||||
pub async fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
|
||||
if let Ok(st) = &mut self.state {
|
||||
st.perhaps_write_key_update(&mut self.common_state).await;
|
||||
}
|
||||
|
||||
let len = self.sendable_plaintext.append_limited_copy(buf);
|
||||
Ok(len)
|
||||
self.common_state.send_some_plaintext(buf).await
|
||||
}
|
||||
|
||||
/// Writes the entire plaintext `buf` into an internal buffer.
|
||||
pub fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<(), Error> {
|
||||
self.sendable_plaintext.append(buf.to_vec());
|
||||
Ok(())
|
||||
/// Write entire buffer into connection.
|
||||
pub async fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
|
||||
let mut pos = 0;
|
||||
while pos < buf.len() {
|
||||
pos += self.write_plaintext(&buf[pos..]).await?;
|
||||
}
|
||||
self.backend.flush().await?;
|
||||
while let Some(msg) = self.backend.next_outgoing().await? {
|
||||
self.queue_tls_message(msg);
|
||||
}
|
||||
Ok(pos)
|
||||
}
|
||||
|
||||
/// Read TLS content from `rd`. This method does internal
|
||||
@@ -690,11 +690,6 @@ impl CommonState {
|
||||
self.received_plaintext.is_empty()
|
||||
}
|
||||
|
||||
/// Returns true if the buffer for sendable plaintext is full.
|
||||
pub fn sendable_plaintext_is_full(&self) -> bool {
|
||||
self.sendable_plaintext.is_full()
|
||||
}
|
||||
|
||||
/// Returns true if the connection is currently performing the TLS
|
||||
/// handshake.
|
||||
///
|
||||
@@ -787,6 +782,15 @@ impl CommonState {
|
||||
}
|
||||
}
|
||||
|
||||
/// Send plaintext application data, fragmenting and
|
||||
/// encrypting it as it goes out.
|
||||
///
|
||||
/// If internal buffers are too small, this function will not accept
|
||||
/// all the data.
|
||||
pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> Result<usize, Error> {
|
||||
self.send_plain(data, Limit::Yes).await
|
||||
}
|
||||
|
||||
// Changing the keys must not span any fragmented handshake
|
||||
// messages. Otherwise the defragmented messages will have
|
||||
// been protected with two different record layer protections,
|
||||
@@ -927,6 +931,32 @@ impl CommonState {
|
||||
self.sendable_tls.write_to_async(wr).await
|
||||
}
|
||||
|
||||
/// Encrypt and send some plaintext `data`. `limit` controls
|
||||
/// whether the per-connection buffer limits apply.
|
||||
///
|
||||
/// Returns the number of bytes written from `data`: this might
|
||||
/// be less than `data.len()` if buffer limits were exceeded.
|
||||
async fn send_plain(&mut self, data: &[u8], limit: Limit) -> Result<usize, Error> {
|
||||
if !self.may_send_application_data {
|
||||
// If we haven't completed handshaking, buffer
|
||||
// plaintext to send once we do.
|
||||
let len = match limit {
|
||||
Limit::Yes => self.sendable_plaintext.append_limited_copy(data),
|
||||
Limit::No => self.sendable_plaintext.append(data.to_vec()),
|
||||
};
|
||||
return Ok(len);
|
||||
}
|
||||
|
||||
debug_assert!(self.record_layer.is_encrypting());
|
||||
|
||||
if data.is_empty() {
|
||||
// Don't send empty fragments.
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
self.send_appdata_encrypt(data, limit).await
|
||||
}
|
||||
|
||||
pub(crate) async fn start_outgoing_traffic(&mut self) -> Result<(), Error> {
|
||||
self.may_send_application_data = true;
|
||||
self.flush_plaintext().await
|
||||
@@ -982,14 +1012,15 @@ impl CommonState {
|
||||
self.sendable_tls.set_limit(limit);
|
||||
}
|
||||
|
||||
/// Send and encrypt any buffered plaintext. Does nothing during handshake.
|
||||
pub async fn flush_plaintext(&mut self) -> Result<(), Error> {
|
||||
/// Send any buffered plaintext. Plaintext is buffered if
|
||||
/// written during handshake.
|
||||
async fn flush_plaintext(&mut self) -> Result<(), Error> {
|
||||
if !self.may_send_application_data {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
while let Some(buf) = self.sendable_plaintext.pop() {
|
||||
self.send_appdata_encrypt(&buf, Limit::No).await?;
|
||||
self.send_plain(&buf, Limit::No).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -35,15 +35,6 @@ impl ChunkVecBuffer {
|
||||
self.chunks.is_empty()
|
||||
}
|
||||
|
||||
/// If the buffer has reached limit.
|
||||
pub(crate) fn is_full(&self) -> bool {
|
||||
if let Some(limit) = self.limit {
|
||||
self.len() >= limit
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// How many bytes we're storing
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
let mut len = 0;
|
||||
|
||||
@@ -247,8 +247,7 @@ async fn servered_client_data_sent() {
|
||||
let (mut client, mut server) =
|
||||
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
|
||||
|
||||
assert_eq!(5, client.write_plaintext(b"hello").unwrap());
|
||||
client.flush_plaintext().await.unwrap();
|
||||
assert_eq!(5, client.write_plaintext(b"hello").await.unwrap());
|
||||
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
send(&mut client, &mut server);
|
||||
@@ -287,7 +286,7 @@ async fn servered_both_data_sent() {
|
||||
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
|
||||
|
||||
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
|
||||
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
|
||||
@@ -433,7 +432,7 @@ async fn server_close_notify() {
|
||||
|
||||
// check that alerts don't overtake appdata
|
||||
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
|
||||
server.send_close_notify();
|
||||
|
||||
receive(&mut server, &mut client);
|
||||
@@ -461,8 +460,7 @@ async fn client_close_notify() {
|
||||
|
||||
// check that alerts don't overtake appdata
|
||||
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
|
||||
client.flush_plaintext().await.unwrap();
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
|
||||
client.send_close_notify().await.unwrap();
|
||||
|
||||
send(&mut client, &mut server);
|
||||
@@ -489,7 +487,7 @@ async fn server_closes_uncleanly() {
|
||||
|
||||
// check that unclean EOF reporting does not overtake appdata
|
||||
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
|
||||
|
||||
receive(&mut server, &mut client);
|
||||
transfer_eof(&mut client);
|
||||
@@ -520,7 +518,7 @@ async fn client_closes_uncleanly() {
|
||||
|
||||
// check that unclean EOF reporting does not overtake appdata
|
||||
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
|
||||
client.process_new_packets().await.unwrap();
|
||||
|
||||
send(&mut client, &mut server);
|
||||
@@ -902,9 +900,20 @@ async fn client_respects_buffer_limit_pre_handshake() {
|
||||
|
||||
client.set_buffer_limit(Some(32));
|
||||
|
||||
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
|
||||
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 12);
|
||||
client.flush_plaintext().await.unwrap();
|
||||
assert_eq!(
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap(),
|
||||
20
|
||||
);
|
||||
assert_eq!(
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap(),
|
||||
12
|
||||
);
|
||||
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
send(&mut client, &mut server);
|
||||
@@ -944,9 +953,20 @@ async fn client_respects_buffer_limit_post_handshake() {
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
client.set_buffer_limit(Some(48));
|
||||
|
||||
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
|
||||
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 6);
|
||||
client.flush_plaintext().await.unwrap();
|
||||
assert_eq!(
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap(),
|
||||
20
|
||||
);
|
||||
assert_eq!(
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap(),
|
||||
6
|
||||
);
|
||||
|
||||
send(&mut client, &mut server);
|
||||
server.process_new_packets().unwrap();
|
||||
@@ -1191,8 +1211,14 @@ async fn client_complete_io_for_write() {
|
||||
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
|
||||
client.write_plaintext(b"01234567890123456789").unwrap();
|
||||
client.write_plaintext(b"01234567890123456789").unwrap();
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap();
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap();
|
||||
{
|
||||
let mut pipe = ServerSession::new(&mut server);
|
||||
let (rdlen, wrlen) = client
|
||||
@@ -1324,8 +1350,7 @@ async fn server_stream_read() {
|
||||
for kt in ALL_KEY_TYPES.iter() {
|
||||
let (mut client, mut server) = make_pair(*kt).await;
|
||||
|
||||
client.write_all_plaintext(b"world").unwrap();
|
||||
client.process_new_packets().await.unwrap();
|
||||
client.write_all_plaintext(b"world").await.unwrap();
|
||||
|
||||
{
|
||||
let mut pipe = ClientSession::new(&mut client);
|
||||
@@ -1341,8 +1366,7 @@ async fn server_streamowned_read() {
|
||||
for kt in ALL_KEY_TYPES.iter() {
|
||||
let (mut client, server) = make_pair(*kt).await;
|
||||
|
||||
client.write_all_plaintext(b"world").unwrap();
|
||||
client.process_new_packets().await.unwrap();
|
||||
client.write_all_plaintext(b"world").await.unwrap();
|
||||
|
||||
{
|
||||
let pipe = ClientSession::new(&mut client);
|
||||
@@ -1361,9 +1385,7 @@ async fn server_streamowned_read() {
|
||||
// errkind: io::ErrorKind::ConnectionAborted,
|
||||
// after: 0,
|
||||
// };
|
||||
// client.write_all_plaintext(b"hello").unwrap();
|
||||
// client.process_new_packets().await.unwrap();
|
||||
//
|
||||
// client.write_all_plaintext(b"hello").await.unwrap();
|
||||
// let mut client_stream = Stream::new(&mut client, &mut pipe);
|
||||
// let rc = client_stream.write(b"world");
|
||||
// assert!(rc.is_err());
|
||||
@@ -1380,9 +1402,7 @@ async fn server_streamowned_read() {
|
||||
// errkind: io::ErrorKind::ConnectionAborted,
|
||||
// after: 1,
|
||||
// };
|
||||
// client.write_all_plaintext(b"hello").unwrap();
|
||||
// client.process_new_packets().await.unwrap();
|
||||
//
|
||||
// client.write_all_plaintext(b"hello").await.unwrap();
|
||||
// let mut client_stream = Stream::new(&mut client, &mut pipe);
|
||||
// let rc = client_stream.write(b"world");
|
||||
// assert_eq!(format!("{:?}", rc), "Ok(5)");
|
||||
@@ -1880,9 +1900,14 @@ async fn servered_write_for_client_appdata() {
|
||||
let (mut client, mut server) = make_pair(KeyType::Rsa).await;
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
|
||||
client.write_all_plaintext(b"01234567890123456789").unwrap();
|
||||
client.write_all_plaintext(b"01234567890123456789").unwrap();
|
||||
client.process_new_packets().await.unwrap();
|
||||
client
|
||||
.write_all_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap();
|
||||
client
|
||||
.write_all_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap();
|
||||
{
|
||||
let mut pipe = ServerSession::new(&mut server);
|
||||
let wrlen = client.write_tls(&mut pipe).unwrap();
|
||||
@@ -1994,10 +2019,11 @@ async fn servered_write_for_server_handshake_no_half_rtt_by_default() {
|
||||
async fn servered_write_for_client_handshake() {
|
||||
let (mut client, mut server) = make_pair(KeyType::Rsa).await;
|
||||
|
||||
client.write_all_plaintext(b"01234567890123456789").unwrap();
|
||||
client.write_all_plaintext(b"0123456789").unwrap();
|
||||
client.process_new_packets().await.unwrap();
|
||||
|
||||
client
|
||||
.write_all_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap();
|
||||
client.write_all_plaintext(b"0123456789").await.unwrap();
|
||||
{
|
||||
let mut pipe = ServerSession::new(&mut server);
|
||||
let wrlen = client.write_tls(&mut pipe).unwrap();
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Cryptographic operations for the TLSNotary TLS client"
|
||||
keywords = ["tls", "mpc", "2pc"]
|
||||
categories = ["cryptography"]
|
||||
license = "Apache-2.0 OR ISC OR MIT"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
version = "0.1.0-alpha.13"
|
||||
edition = "2021"
|
||||
|
||||
[lints]
|
||||
|
||||
@@ -4,7 +4,7 @@ authors = ["TLSNotary Team"]
|
||||
keywords = ["tls", "mpc", "2pc", "prover"]
|
||||
categories = ["cryptography"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
version = "0.1.0-alpha.13"
|
||||
edition = "2024"
|
||||
|
||||
[lints]
|
||||
@@ -12,7 +12,6 @@ workspace = true
|
||||
|
||||
[features]
|
||||
default = ["rayon"]
|
||||
mozilla-certs = ["tlsn-core/mozilla-certs"]
|
||||
rayon = ["mpz-zk/rayon", "mpz-garble/rayon"]
|
||||
web = ["dep:web-spawn"]
|
||||
|
||||
@@ -21,18 +20,18 @@ tlsn-attestation = { workspace = true }
|
||||
tlsn-core = { workspace = true }
|
||||
tlsn-deap = { workspace = true }
|
||||
tlsn-tls-client = { workspace = true }
|
||||
tlsn-tls-client-async = { workspace = true }
|
||||
tlsn-tls-core = { workspace = true }
|
||||
tlsn-mpc-tls = { workspace = true }
|
||||
tlsn-cipher = { workspace = true }
|
||||
|
||||
futures-plex = { workspace = true }
|
||||
serio = { workspace = true, features = ["compat"] }
|
||||
uid-mux = { workspace = true, features = ["serio"] }
|
||||
web-spawn = { workspace = true, optional = true }
|
||||
|
||||
mpz-circuits = { workspace = true, features = ["aes"] }
|
||||
mpz-common = { workspace = true }
|
||||
mpz-core = { workspace = true }
|
||||
mpz-circuits = { workspace = true }
|
||||
mpz-garble = { workspace = true }
|
||||
mpz-garble-core = { workspace = true }
|
||||
mpz-hash = { workspace = true }
|
||||
@@ -41,10 +40,10 @@ mpz-ole = { workspace = true }
|
||||
mpz-ot = { workspace = true }
|
||||
mpz-vm-core = { workspace = true }
|
||||
mpz-zk = { workspace = true }
|
||||
mpz-ideal-vm = { workspace = true }
|
||||
|
||||
aes = { workspace = true }
|
||||
ctr = { workspace = true }
|
||||
derive_builder = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
opaque-debug = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
@@ -57,7 +56,6 @@ serde = { workspace = true, features = ["derive"] }
|
||||
ghash = { workspace = true }
|
||||
semver = { workspace = true, features = ["serde"] }
|
||||
once_cell = { workspace = true }
|
||||
pin-project-lite = { workspace = true }
|
||||
rangeset = { workspace = true }
|
||||
webpki-roots = { workspace = true }
|
||||
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
fn main() {
|
||||
println!("cargo:rustc-check-cfg=cfg(tlsn_insecure)");
|
||||
}
|
||||
368
crates/tlsn/src/config.rs
Normal file
368
crates/tlsn/src/config.rs
Normal file
@@ -0,0 +1,368 @@
|
||||
//! TLSNotary protocol config and config utilities.
|
||||
use core::fmt;
|
||||
use once_cell::sync::Lazy;
|
||||
use semver::Version;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
|
||||
pub use tlsn_core::webpki::{CertificateDer, PrivateKeyDer, RootCertStore};
|
||||
|
||||
// Default is 32 bytes to decrypt the TLS protocol messages.
|
||||
const DEFAULT_MAX_RECV_ONLINE: usize = 32;
|
||||
// Default maximum number of TLS records to allow.
|
||||
//
|
||||
// This would allow for up to 50Mb upload from prover to verifier.
|
||||
const DEFAULT_RECORDS_LIMIT: usize = 256;
|
||||
|
||||
// Current version that is running.
|
||||
static VERSION: Lazy<Version> = Lazy::new(|| {
|
||||
Version::parse(env!("CARGO_PKG_VERSION"))
|
||||
.map_err(|err| ProtocolConfigError::new(ErrorKind::Version, err))
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
/// Protocol configuration to be set up initially by prover and verifier.
|
||||
#[derive(derive_builder::Builder, Clone, Debug, Deserialize, Serialize)]
|
||||
#[builder(build_fn(validate = "Self::validate"))]
|
||||
pub struct ProtocolConfig {
|
||||
/// Maximum number of bytes that can be sent.
|
||||
max_sent_data: usize,
|
||||
/// Maximum number of application data records that can be sent.
|
||||
#[builder(setter(strip_option), default)]
|
||||
max_sent_records: Option<usize>,
|
||||
/// Maximum number of bytes that can be decrypted online, i.e. while the
|
||||
/// MPC-TLS connection is active.
|
||||
#[builder(default = "DEFAULT_MAX_RECV_ONLINE")]
|
||||
max_recv_data_online: usize,
|
||||
/// Maximum number of bytes that can be received.
|
||||
max_recv_data: usize,
|
||||
/// Maximum number of received application data records that can be
|
||||
/// decrypted online, i.e. while the MPC-TLS connection is active.
|
||||
#[builder(setter(strip_option), default)]
|
||||
max_recv_records_online: Option<usize>,
|
||||
/// Whether the `deferred decryption` feature is toggled on from the start
|
||||
/// of the MPC-TLS connection.
|
||||
#[builder(default = "true")]
|
||||
defer_decryption_from_start: bool,
|
||||
/// Network settings.
|
||||
#[builder(default)]
|
||||
network: NetworkSetting,
|
||||
/// Version that is being run by prover/verifier.
|
||||
#[builder(setter(skip), default = "VERSION.clone()")]
|
||||
version: Version,
|
||||
}
|
||||
|
||||
impl ProtocolConfigBuilder {
|
||||
fn validate(&self) -> Result<(), String> {
|
||||
if self.max_recv_data_online > self.max_recv_data {
|
||||
return Err(
|
||||
"max_recv_data_online must be smaller or equal to max_recv_data".to_string(),
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl ProtocolConfig {
|
||||
/// Creates a new builder for `ProtocolConfig`.
|
||||
pub fn builder() -> ProtocolConfigBuilder {
|
||||
ProtocolConfigBuilder::default()
|
||||
}
|
||||
|
||||
/// Returns the maximum number of bytes that can be sent.
|
||||
pub fn max_sent_data(&self) -> usize {
|
||||
self.max_sent_data
|
||||
}
|
||||
|
||||
/// Returns the maximum number of application data records that can
|
||||
/// be sent.
|
||||
pub fn max_sent_records(&self) -> Option<usize> {
|
||||
self.max_sent_records
|
||||
}
|
||||
|
||||
/// Returns the maximum number of bytes that can be decrypted online.
|
||||
pub fn max_recv_data_online(&self) -> usize {
|
||||
self.max_recv_data_online
|
||||
}
|
||||
|
||||
/// Returns the maximum number of bytes that can be received.
|
||||
pub fn max_recv_data(&self) -> usize {
|
||||
self.max_recv_data
|
||||
}
|
||||
|
||||
/// Returns the maximum number of received application data records that
|
||||
/// can be decrypted online.
|
||||
pub fn max_recv_records_online(&self) -> Option<usize> {
|
||||
self.max_recv_records_online
|
||||
}
|
||||
|
||||
/// Returns whether the `deferred decryption` feature is toggled on from the
|
||||
/// start of the MPC-TLS connection.
|
||||
pub fn defer_decryption_from_start(&self) -> bool {
|
||||
self.defer_decryption_from_start
|
||||
}
|
||||
|
||||
/// Returns the network settings.
|
||||
pub fn network(&self) -> NetworkSetting {
|
||||
self.network
|
||||
}
|
||||
}
|
||||
|
||||
/// Protocol configuration validator used by checker (i.e. verifier) to perform
|
||||
/// compatibility check with the peer's (i.e. the prover's) configuration.
|
||||
#[derive(derive_builder::Builder, Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ProtocolConfigValidator {
|
||||
/// Maximum number of bytes that can be sent.
|
||||
max_sent_data: usize,
|
||||
/// Maximum number of application data records that can be sent.
|
||||
#[builder(default = "DEFAULT_RECORDS_LIMIT")]
|
||||
max_sent_records: usize,
|
||||
/// Maximum number of bytes that can be received.
|
||||
max_recv_data: usize,
|
||||
/// Maximum number of application data records that can be received online.
|
||||
#[builder(default = "DEFAULT_RECORDS_LIMIT")]
|
||||
max_recv_records_online: usize,
|
||||
/// Version that is being run by checker.
|
||||
#[builder(setter(skip), default = "VERSION.clone()")]
|
||||
version: Version,
|
||||
}
|
||||
|
||||
impl ProtocolConfigValidator {
|
||||
/// Creates a new builder for `ProtocolConfigValidator`.
|
||||
pub fn builder() -> ProtocolConfigValidatorBuilder {
|
||||
ProtocolConfigValidatorBuilder::default()
|
||||
}
|
||||
|
||||
/// Returns the maximum number of bytes that can be sent.
|
||||
pub fn max_sent_data(&self) -> usize {
|
||||
self.max_sent_data
|
||||
}
|
||||
|
||||
/// Returns the maximum number of application data records that can
|
||||
/// be sent.
|
||||
pub fn max_sent_records(&self) -> usize {
|
||||
self.max_sent_records
|
||||
}
|
||||
|
||||
/// Returns the maximum number of bytes that can be received.
|
||||
pub fn max_recv_data(&self) -> usize {
|
||||
self.max_recv_data
|
||||
}
|
||||
|
||||
/// Returns the maximum number of application data records that can
|
||||
/// be received online.
|
||||
pub fn max_recv_records_online(&self) -> usize {
|
||||
self.max_recv_records_online
|
||||
}
|
||||
|
||||
/// Performs compatibility check of the protocol configuration between
|
||||
/// prover and verifier.
|
||||
pub fn validate(&self, config: &ProtocolConfig) -> Result<(), ProtocolConfigError> {
|
||||
self.check_max_transcript_size(config.max_sent_data, config.max_recv_data)?;
|
||||
self.check_max_records(config.max_sent_records, config.max_recv_records_online)?;
|
||||
self.check_version(&config.version)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Checks if both the sent and recv data are within limits.
|
||||
fn check_max_transcript_size(
|
||||
&self,
|
||||
max_sent_data: usize,
|
||||
max_recv_data: usize,
|
||||
) -> Result<(), ProtocolConfigError> {
|
||||
if max_sent_data > self.max_sent_data {
|
||||
return Err(ProtocolConfigError::max_transcript_size(format!(
|
||||
"max_sent_data {:?} is greater than the configured limit {:?}",
|
||||
max_sent_data, self.max_sent_data,
|
||||
)));
|
||||
}
|
||||
|
||||
if max_recv_data > self.max_recv_data {
|
||||
return Err(ProtocolConfigError::max_transcript_size(format!(
|
||||
"max_recv_data {:?} is greater than the configured limit {:?}",
|
||||
max_recv_data, self.max_recv_data,
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_max_records(
|
||||
&self,
|
||||
max_sent_records: Option<usize>,
|
||||
max_recv_records_online: Option<usize>,
|
||||
) -> Result<(), ProtocolConfigError> {
|
||||
if let Some(max_sent_records) = max_sent_records
|
||||
&& max_sent_records > self.max_sent_records
|
||||
{
|
||||
return Err(ProtocolConfigError::max_record_count(format!(
|
||||
"max_sent_records {} is greater than the configured limit {}",
|
||||
max_sent_records, self.max_sent_records,
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(max_recv_records_online) = max_recv_records_online
|
||||
&& max_recv_records_online > self.max_recv_records_online
|
||||
{
|
||||
return Err(ProtocolConfigError::max_record_count(format!(
|
||||
"max_recv_records_online {} is greater than the configured limit {}",
|
||||
max_recv_records_online, self.max_recv_records_online,
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Checks if both versions are the same (might support check for different but
|
||||
// compatible versions in the future).
|
||||
fn check_version(&self, peer_version: &Version) -> Result<(), ProtocolConfigError> {
|
||||
if *peer_version != self.version {
|
||||
return Err(ProtocolConfigError::version(format!(
|
||||
"prover's version {:?} is different from verifier's version {:?}",
|
||||
peer_version, self.version
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Settings for the network environment.
|
||||
///
|
||||
/// Provides optimization options to adapt the protocol to different network
|
||||
/// situations.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum NetworkSetting {
|
||||
/// Reduces network round-trips at the expense of consuming more network
|
||||
/// bandwidth.
|
||||
Bandwidth,
|
||||
/// Reduces network bandwidth utilization at the expense of more network
|
||||
/// round-trips.
|
||||
Latency,
|
||||
}
|
||||
|
||||
impl Default for NetworkSetting {
|
||||
fn default() -> Self {
|
||||
Self::Latency
|
||||
}
|
||||
}
|
||||
|
||||
/// A ProtocolConfig error.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub struct ProtocolConfigError {
|
||||
kind: ErrorKind,
|
||||
#[source]
|
||||
source: Option<Box<dyn Error + Send + Sync>>,
|
||||
}
|
||||
|
||||
impl ProtocolConfigError {
|
||||
fn new<E>(kind: ErrorKind, source: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn Error + Send + Sync>>,
|
||||
{
|
||||
Self {
|
||||
kind,
|
||||
source: Some(source.into()),
|
||||
}
|
||||
}
|
||||
|
||||
fn max_transcript_size(msg: impl Into<String>) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::MaxTranscriptSize,
|
||||
source: Some(msg.into().into()),
|
||||
}
|
||||
}
|
||||
|
||||
fn max_record_count(msg: impl Into<String>) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::MaxRecordCount,
|
||||
source: Some(msg.into().into()),
|
||||
}
|
||||
}
|
||||
|
||||
fn version(msg: impl Into<String>) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::Version,
|
||||
source: Some(msg.into().into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ProtocolConfigError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self.kind {
|
||||
ErrorKind::MaxTranscriptSize => write!(f, "max transcript size exceeded")?,
|
||||
ErrorKind::MaxRecordCount => write!(f, "max record count exceeded")?,
|
||||
ErrorKind::Version => write!(f, "version error")?,
|
||||
}
|
||||
|
||||
if let Some(ref source) = self.source {
|
||||
write!(f, " caused by: {source}")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ErrorKind {
|
||||
MaxTranscriptSize,
|
||||
MaxRecordCount,
|
||||
Version,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use rstest::{fixture, rstest};
|
||||
|
||||
const TEST_MAX_SENT_LIMIT: usize = 1 << 12;
|
||||
const TEST_MAX_RECV_LIMIT: usize = 1 << 14;
|
||||
|
||||
#[fixture]
|
||||
#[once]
|
||||
fn config_validator() -> ProtocolConfigValidator {
|
||||
ProtocolConfigValidator::builder()
|
||||
.max_sent_data(TEST_MAX_SENT_LIMIT)
|
||||
.max_recv_data(TEST_MAX_RECV_LIMIT)
|
||||
.build()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::same_max_sent_recv_data(TEST_MAX_SENT_LIMIT, TEST_MAX_RECV_LIMIT)]
|
||||
#[case::smaller_max_sent_data(1 << 11, TEST_MAX_RECV_LIMIT)]
|
||||
#[case::smaller_max_recv_data(TEST_MAX_SENT_LIMIT, 1 << 13)]
|
||||
#[case::smaller_max_sent_recv_data(1 << 7, 1 << 9)]
|
||||
fn test_check_success(
|
||||
config_validator: &ProtocolConfigValidator,
|
||||
#[case] max_sent_data: usize,
|
||||
#[case] max_recv_data: usize,
|
||||
) {
|
||||
let peer_config = ProtocolConfig::builder()
|
||||
.max_sent_data(max_sent_data)
|
||||
.max_recv_data(max_recv_data)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
assert!(config_validator.validate(&peer_config).is_ok())
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::bigger_max_sent_data(1 << 13, TEST_MAX_RECV_LIMIT)]
|
||||
#[case::bigger_max_recv_data(1 << 10, 1 << 16)]
|
||||
#[case::bigger_max_sent_recv_data(1 << 14, 1 << 21)]
|
||||
fn test_check_fail(
|
||||
config_validator: &ProtocolConfigValidator,
|
||||
#[case] max_sent_data: usize,
|
||||
#[case] max_recv_data: usize,
|
||||
) {
|
||||
let peer_config = ProtocolConfig::builder()
|
||||
.max_sent_data(max_sent_data)
|
||||
.max_recv_data(max_recv_data)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
assert!(config_validator.validate(&peer_config).is_err())
|
||||
}
|
||||
}
|
||||
21
crates/tlsn/src/context.rs
Normal file
21
crates/tlsn/src/context.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
//! Execution context.
|
||||
|
||||
use mpz_common::context::Multithread;
|
||||
|
||||
use crate::mux::MuxControl;
|
||||
|
||||
/// Maximum concurrency for multi-threaded context.
|
||||
pub(crate) const MAX_CONCURRENCY: usize = 8;
|
||||
|
||||
/// Builds a multi-threaded context with the given muxer.
|
||||
pub(crate) fn build_mt_context(mux: MuxControl) -> Multithread {
|
||||
let builder = Multithread::builder().mux(mux).concurrency(MAX_CONCURRENCY);
|
||||
|
||||
#[cfg(all(feature = "web", target_arch = "wasm32"))]
|
||||
let builder = builder.spawn_handler(|f| {
|
||||
let _ = web_spawn::spawn(f);
|
||||
Ok(())
|
||||
});
|
||||
|
||||
builder.build().unwrap()
|
||||
}
|
||||
@@ -4,30 +4,18 @@
|
||||
#![deny(clippy::all)]
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
pub mod config;
|
||||
pub(crate) mod context;
|
||||
pub(crate) mod ghash;
|
||||
pub(crate) mod map;
|
||||
pub(crate) mod mpz;
|
||||
pub(crate) mod msg;
|
||||
pub(crate) mod mux;
|
||||
pub mod prover;
|
||||
pub(crate) mod tag;
|
||||
pub(crate) mod transcript_internal;
|
||||
pub(crate) mod utils;
|
||||
pub mod verifier;
|
||||
|
||||
pub use tlsn_attestation as attestation;
|
||||
pub use tlsn_core::{config, connection, hash, transcript, webpki};
|
||||
|
||||
use std::sync::LazyLock;
|
||||
|
||||
use semver::Version;
|
||||
|
||||
// Package version.
|
||||
pub(crate) static VERSION: LazyLock<Version> = LazyLock::new(|| {
|
||||
Version::parse(env!("CARGO_PKG_VERSION")).expect("cargo pkg version should be a valid semver")
|
||||
});
|
||||
|
||||
const BUF_CAP: usize = 16 * 1024;
|
||||
pub use tlsn_core::{connection, hash, transcript};
|
||||
|
||||
/// The party's role in the TLSN protocol.
|
||||
///
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use mpz_memory_core::{Vector, binary::U8};
|
||||
use rangeset::set::RangeSet;
|
||||
use rangeset::RangeSet;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub(crate) struct RangeMap<T> {
|
||||
@@ -77,7 +77,7 @@ where
|
||||
|
||||
pub(crate) fn index(&self, idx: &RangeSet<usize>) -> Option<Self> {
|
||||
let mut map = Vec::new();
|
||||
for idx in idx.iter() {
|
||||
for idx in idx.iter_ranges() {
|
||||
let pos = match self.map.binary_search_by(|(base, _)| base.cmp(&idx.start)) {
|
||||
Ok(i) => i,
|
||||
Err(0) => return None,
|
||||
|
||||
@@ -1,233 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use mpc_tls::{MpcTlsFollower, MpcTlsLeader, SessionKeys};
|
||||
use mpz_common::Context;
|
||||
use mpz_core::Block;
|
||||
#[cfg(not(tlsn_insecure))]
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
|
||||
use mpz_garble_core::Delta;
|
||||
use mpz_memory_core::{
|
||||
Vector,
|
||||
binary::U8,
|
||||
correlated::{Key, Mac},
|
||||
};
|
||||
#[cfg(not(tlsn_insecure))]
|
||||
use mpz_ot::cot::{DerandCOTReceiver, DerandCOTSender};
|
||||
use mpz_ot::{
|
||||
chou_orlandi as co, ferret, kos,
|
||||
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
|
||||
};
|
||||
use mpz_zk::{Prover, Verifier};
|
||||
#[cfg(not(tlsn_insecure))]
|
||||
use rand::Rng;
|
||||
use tlsn_core::config::tls_commit::mpc::{MpcTlsConfig, NetworkSetting};
|
||||
use tlsn_deap::Deap;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::transcript_internal::commit::encoding::{KeyStore, MacStore};
|
||||
|
||||
#[cfg(not(tlsn_insecure))]
|
||||
pub(crate) type ProverMpc =
|
||||
Garbler<DerandCOTSender<SharedRCOTSender<kos::Sender<co::Receiver>, Block>>>;
|
||||
#[cfg(tlsn_insecure)]
|
||||
pub(crate) type ProverMpc = mpz_ideal_vm::IdealVm;
|
||||
|
||||
#[cfg(not(tlsn_insecure))]
|
||||
pub(crate) type ProverZk =
|
||||
Prover<SharedRCOTReceiver<ferret::Receiver<kos::Receiver<co::Sender>>, bool, Block>>;
|
||||
#[cfg(tlsn_insecure)]
|
||||
pub(crate) type ProverZk = mpz_ideal_vm::IdealVm;
|
||||
|
||||
#[cfg(not(tlsn_insecure))]
|
||||
pub(crate) type VerifierMpc =
|
||||
Evaluator<DerandCOTReceiver<SharedRCOTReceiver<kos::Receiver<co::Sender>, bool, Block>>>;
|
||||
#[cfg(tlsn_insecure)]
|
||||
pub(crate) type VerifierMpc = mpz_ideal_vm::IdealVm;
|
||||
|
||||
#[cfg(not(tlsn_insecure))]
|
||||
pub(crate) type VerifierZk =
|
||||
Verifier<SharedRCOTSender<ferret::Sender<kos::Sender<co::Receiver>>, Block>>;
|
||||
#[cfg(tlsn_insecure)]
|
||||
pub(crate) type VerifierZk = mpz_ideal_vm::IdealVm;
|
||||
|
||||
pub(crate) struct ProverDeps {
|
||||
pub(crate) vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
|
||||
pub(crate) mpc_tls: MpcTlsLeader,
|
||||
}
|
||||
|
||||
pub(crate) fn build_prover_deps(config: MpcTlsConfig, ctx: Context) -> ProverDeps {
|
||||
let mut rng = rand::rng();
|
||||
let delta = Delta::new(Block::random(&mut rng));
|
||||
|
||||
let base_ot_send = co::Sender::default();
|
||||
let base_ot_recv = co::Receiver::default();
|
||||
let rcot_send = kos::Sender::new(
|
||||
kos::SenderConfig::default(),
|
||||
delta.into_inner(),
|
||||
base_ot_recv,
|
||||
);
|
||||
let rcot_recv = kos::Receiver::new(kos::ReceiverConfig::default(), base_ot_send);
|
||||
let rcot_recv = ferret::Receiver::new(
|
||||
ferret::FerretConfig::builder()
|
||||
.lpn_type(ferret::LpnType::Regular)
|
||||
.build()
|
||||
.expect("ferret config is valid"),
|
||||
Block::random(&mut rng),
|
||||
rcot_recv,
|
||||
);
|
||||
|
||||
let rcot_send = SharedRCOTSender::new(rcot_send);
|
||||
let rcot_recv = SharedRCOTReceiver::new(rcot_recv);
|
||||
|
||||
#[cfg(not(tlsn_insecure))]
|
||||
let mpc = ProverMpc::new(DerandCOTSender::new(rcot_send.clone()), rng.random(), delta);
|
||||
#[cfg(tlsn_insecure)]
|
||||
let mpc = mpz_ideal_vm::IdealVm::new();
|
||||
|
||||
#[cfg(not(tlsn_insecure))]
|
||||
let zk = ProverZk::new(Default::default(), rcot_recv.clone());
|
||||
#[cfg(tlsn_insecure)]
|
||||
let zk = mpz_ideal_vm::IdealVm::new();
|
||||
|
||||
let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Leader, mpc, zk)));
|
||||
let mpc_tls = MpcTlsLeader::new(
|
||||
build_mpc_tls_config(config),
|
||||
ctx,
|
||||
vm.clone(),
|
||||
(rcot_send.clone(), rcot_send.clone(), rcot_send),
|
||||
rcot_recv,
|
||||
);
|
||||
|
||||
ProverDeps { vm, mpc_tls }
|
||||
}
|
||||
|
||||
pub(crate) struct VerifierDeps {
|
||||
pub(crate) vm: Arc<Mutex<Deap<VerifierMpc, VerifierZk>>>,
|
||||
pub(crate) mpc_tls: MpcTlsFollower,
|
||||
}
|
||||
|
||||
pub(crate) fn build_verifier_deps(config: MpcTlsConfig, ctx: Context) -> VerifierDeps {
|
||||
let mut rng = rand::rng();
|
||||
|
||||
let delta = Delta::random(&mut rng);
|
||||
let base_ot_send = co::Sender::default();
|
||||
let base_ot_recv = co::Receiver::default();
|
||||
let rcot_send = kos::Sender::new(
|
||||
kos::SenderConfig::default(),
|
||||
delta.into_inner(),
|
||||
base_ot_recv,
|
||||
);
|
||||
let rcot_send = ferret::Sender::new(
|
||||
ferret::FerretConfig::builder()
|
||||
.lpn_type(ferret::LpnType::Regular)
|
||||
.build()
|
||||
.expect("ferret config is valid"),
|
||||
Block::random(&mut rng),
|
||||
rcot_send,
|
||||
);
|
||||
let rcot_recv = kos::Receiver::new(kos::ReceiverConfig::default(), base_ot_send);
|
||||
|
||||
let rcot_send = SharedRCOTSender::new(rcot_send);
|
||||
let rcot_recv = SharedRCOTReceiver::new(rcot_recv);
|
||||
|
||||
#[cfg(not(tlsn_insecure))]
|
||||
let mpc = VerifierMpc::new(DerandCOTReceiver::new(rcot_recv.clone()));
|
||||
#[cfg(tlsn_insecure)]
|
||||
let mpc = mpz_ideal_vm::IdealVm::new();
|
||||
|
||||
#[cfg(not(tlsn_insecure))]
|
||||
let zk = VerifierZk::new(Default::default(), delta, rcot_send.clone());
|
||||
#[cfg(tlsn_insecure)]
|
||||
let zk = mpz_ideal_vm::IdealVm::new();
|
||||
|
||||
let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Follower, mpc, zk)));
|
||||
let mpc_tls = MpcTlsFollower::new(
|
||||
build_mpc_tls_config(config),
|
||||
ctx,
|
||||
vm.clone(),
|
||||
rcot_send,
|
||||
(rcot_recv.clone(), rcot_recv.clone(), rcot_recv),
|
||||
);
|
||||
|
||||
VerifierDeps { vm, mpc_tls }
|
||||
}
|
||||
|
||||
fn build_mpc_tls_config(config: MpcTlsConfig) -> mpc_tls::Config {
|
||||
let mut builder = mpc_tls::Config::builder();
|
||||
|
||||
builder
|
||||
.defer_decryption(config.defer_decryption_from_start())
|
||||
.max_sent(config.max_sent_data())
|
||||
.max_recv_online(config.max_recv_data_online())
|
||||
.max_recv(config.max_recv_data());
|
||||
|
||||
if let Some(max_sent_records) = config.max_sent_records() {
|
||||
builder.max_sent_records(max_sent_records);
|
||||
}
|
||||
|
||||
if let Some(max_recv_records_online) = config.max_recv_records_online() {
|
||||
builder.max_recv_records_online(max_recv_records_online);
|
||||
}
|
||||
|
||||
if let NetworkSetting::Latency = config.network() {
|
||||
builder.low_bandwidth();
|
||||
}
|
||||
|
||||
builder.build().unwrap()
|
||||
}
|
||||
|
||||
pub(crate) fn translate_keys<Mpc, Zk>(keys: &mut SessionKeys, vm: &Deap<Mpc, Zk>) {
|
||||
keys.client_write_key = vm
|
||||
.translate(keys.client_write_key)
|
||||
.expect("VM memory should be consistent");
|
||||
keys.client_write_iv = vm
|
||||
.translate(keys.client_write_iv)
|
||||
.expect("VM memory should be consistent");
|
||||
keys.server_write_key = vm
|
||||
.translate(keys.server_write_key)
|
||||
.expect("VM memory should be consistent");
|
||||
keys.server_write_iv = vm
|
||||
.translate(keys.server_write_iv)
|
||||
.expect("VM memory should be consistent");
|
||||
keys.server_write_mac_key = vm
|
||||
.translate(keys.server_write_mac_key)
|
||||
.expect("VM memory should be consistent");
|
||||
}
|
||||
|
||||
impl<T> KeyStore for Verifier<T> {
|
||||
fn delta(&self) -> &Delta {
|
||||
self.delta()
|
||||
}
|
||||
|
||||
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]> {
|
||||
self.get_keys(data).ok()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> MacStore for Prover<T> {
|
||||
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]> {
|
||||
self.get_macs(data).ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(tlsn_insecure)]
|
||||
mod insecure {
|
||||
use super::*;
|
||||
use mpz_ideal_vm::IdealVm;
|
||||
|
||||
impl KeyStore for IdealVm {
|
||||
fn delta(&self) -> &Delta {
|
||||
unimplemented!("encodings not supported in insecure mode")
|
||||
}
|
||||
|
||||
fn get_keys(&self, _data: Vector<U8>) -> Option<&[Key]> {
|
||||
unimplemented!("encodings not supported in insecure mode")
|
||||
}
|
||||
}
|
||||
|
||||
impl MacStore for IdealVm {
|
||||
fn get_macs(&self, _data: Vector<U8>) -> Option<&[Mac]> {
|
||||
unimplemented!("encodings not supported in insecure mode")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
use semver::Version;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use tlsn_core::{
|
||||
config::{prove::ProveRequest, tls_commit::TlsCommitRequest},
|
||||
connection::{HandshakeData, ServerName},
|
||||
transcript::PartialTranscript,
|
||||
};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub(crate) struct TlsCommitRequestMsg {
|
||||
pub(crate) request: TlsCommitRequest,
|
||||
pub(crate) version: Version,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub(crate) struct ProveRequestMsg {
|
||||
pub(crate) request: ProveRequest,
|
||||
pub(crate) handshake: Option<(ServerName, HandshakeData)>,
|
||||
pub(crate) transcript: Option<PartialTranscript>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub(crate) struct Response {
|
||||
pub(crate) result: Result<(), RejectionReason>,
|
||||
}
|
||||
|
||||
impl Response {
|
||||
pub(crate) fn ok() -> Self {
|
||||
Self { result: Ok(()) }
|
||||
}
|
||||
|
||||
pub(crate) fn err(msg: Option<impl Into<String>>) -> Self {
|
||||
Self {
|
||||
result: Err(RejectionReason(msg.map(Into::into))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub(crate) struct RejectionReason(Option<String>);
|
||||
|
||||
impl From<RejectionReason> for crate::prover::ProverError {
|
||||
fn from(value: RejectionReason) -> Self {
|
||||
if let Some(msg) = value.0 {
|
||||
crate::prover::ProverError::config(format!("verifier rejected with reason: {msg}"))
|
||||
} else {
|
||||
crate::prover::ProverError::config("verifier rejected without providing a reason")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,50 +1,54 @@
|
||||
//! Prover.
|
||||
|
||||
mod client;
|
||||
mod conn;
|
||||
mod control;
|
||||
mod config;
|
||||
mod error;
|
||||
mod future;
|
||||
mod prove;
|
||||
pub mod state;
|
||||
|
||||
pub use conn::TlsConnection;
|
||||
pub use control::ProverControl;
|
||||
pub use config::{ProverConfig, ProverConfigBuilder, TlsConfig, TlsConfigBuilder};
|
||||
pub use error::ProverError;
|
||||
pub use tlsn_core::ProverOutput;
|
||||
|
||||
use crate::{
|
||||
BUF_CAP, Role,
|
||||
mpz::{ProverDeps, build_prover_deps, translate_keys},
|
||||
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
|
||||
mux::attach_mux,
|
||||
prover::{
|
||||
client::{MpcTlsClient, TlsOutput},
|
||||
state::ConnectedProj,
|
||||
},
|
||||
utils::{CopyIo, await_with_copy_io, build_mt_context},
|
||||
};
|
||||
|
||||
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, FutureExt, TryFutureExt, ready};
|
||||
pub use future::ProverFuture;
|
||||
use rustls_pki_types::CertificateDer;
|
||||
use serio::{SinkExt, stream::IoStreamExt};
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
pub use tlsn_core::{ProveConfig, ProveConfigBuilder, ProveConfigBuilderError, ProverOutput};
|
||||
|
||||
use mpz_common::Context;
|
||||
use mpz_core::Block;
|
||||
use mpz_garble_core::Delta;
|
||||
use mpz_vm_core::prelude::*;
|
||||
use mpz_zk::ProverConfig as ZkProverConfig;
|
||||
use webpki::anchor_from_trusted_cert;
|
||||
|
||||
use crate::{Role, context::build_mt_context, mux::attach_mux, tag::verify_tags};
|
||||
|
||||
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
|
||||
use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
|
||||
use rand::Rng;
|
||||
use serio::SinkExt;
|
||||
use std::sync::Arc;
|
||||
use tls_client::{ClientConnection, ServerName as TlsServerName};
|
||||
use tls_client_async::{TlsConnection, bind_client};
|
||||
use tlsn_core::{
|
||||
config::{
|
||||
prove::ProveConfig,
|
||||
prover::ProverConfig,
|
||||
tls::TlsClientConfig,
|
||||
tls_commit::{TlsCommitConfig, TlsCommitProtocolConfig},
|
||||
},
|
||||
connection::{HandshakeData, ServerName},
|
||||
connection::ServerName,
|
||||
transcript::{TlsTranscript, Transcript},
|
||||
};
|
||||
use tracing::{Span, debug, info_span, instrument};
|
||||
use webpki::anchor_from_trusted_cert;
|
||||
use tlsn_deap::Deap;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use tracing::{Instrument, Span, debug, info, info_span, instrument};
|
||||
|
||||
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
|
||||
mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>,
|
||||
mpz_core::Block,
|
||||
>;
|
||||
pub(crate) type RCOTReceiver = mpz_ot::rcot::shared::SharedRCOTReceiver<
|
||||
mpz_ot::ferret::Receiver<mpz_ot::kos::Receiver<mpz_ot::chou_orlandi::Sender>>,
|
||||
bool,
|
||||
mpz_core::Block,
|
||||
>;
|
||||
pub(crate) type Mpc =
|
||||
mpz_garble::protocol::semihonest::Garbler<mpz_ot::cot::DerandCOTSender<RCOTSender>>;
|
||||
pub(crate) type Zk = mpz_zk::Prover<RCOTReceiver>;
|
||||
|
||||
/// A prover instance.
|
||||
#[derive(Debug)]
|
||||
@@ -69,66 +73,34 @@ impl Prover<state::Initialized> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Starts the TLS commitment protocol.
|
||||
/// Sets up the prover.
|
||||
///
|
||||
/// This initiates the TLS commitment protocol, including performing any
|
||||
/// necessary preprocessing operations.
|
||||
/// This performs all MPC setup prior to establishing the connection to the
|
||||
/// application server.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - The TLS commitment configuration.
|
||||
/// * `verifier_io` - The IO to the TLS verifier.
|
||||
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin>(
|
||||
self,
|
||||
config: TlsCommitConfig,
|
||||
verifier_io: S,
|
||||
) -> Result<Prover<state::CommitAccepted>, ProverError> {
|
||||
let (duplex_a, mut duplex_b) = futures_plex::duplex(BUF_CAP);
|
||||
let fut = Box::pin(self.commit_inner(config, duplex_a).fuse());
|
||||
let mut prover = await_with_copy_io(fut, verifier_io, &mut duplex_b).await?;
|
||||
|
||||
prover.state.verifier_io = Some(duplex_b);
|
||||
Ok(prover)
|
||||
}
|
||||
|
||||
/// * `socket` - The socket to the TLS verifier.
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn commit_inner<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
pub async fn setup<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
self,
|
||||
config: TlsCommitConfig,
|
||||
verifier_io: S,
|
||||
) -> Result<Prover<state::CommitAccepted>, ProverError> {
|
||||
let (mut mux_fut, mux_ctrl) = attach_mux(verifier_io, Role::Prover);
|
||||
socket: S,
|
||||
) -> Result<Prover<state::Setup>, ProverError> {
|
||||
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
|
||||
let mut mt = build_mt_context(mux_ctrl.clone());
|
||||
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
|
||||
|
||||
// Sends protocol configuration to verifier for compatibility check.
|
||||
mux_fut
|
||||
.poll_with(async {
|
||||
ctx.io_mut()
|
||||
.send(TlsCommitRequestMsg {
|
||||
request: config.to_request(),
|
||||
version: crate::VERSION.clone(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
ctx.io_mut()
|
||||
.expect_next::<Response>()
|
||||
.await?
|
||||
.result
|
||||
.map_err(ProverError::from)
|
||||
})
|
||||
.poll_with(ctx.io_mut().send(self.config.protocol_config().clone()))
|
||||
.await?;
|
||||
|
||||
let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = config.protocol().clone() else {
|
||||
unreachable!("only MPC TLS is supported");
|
||||
};
|
||||
|
||||
let ProverDeps { vm, mut mpc_tls } = build_prover_deps(mpc_tls_config, ctx);
|
||||
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx);
|
||||
|
||||
// Allocate resources for MPC-TLS in the VM.
|
||||
let mut keys = mpc_tls.alloc()?;
|
||||
let vm_lock = vm.try_lock().expect("VM is not locked");
|
||||
translate_keys(&mut keys, &vm_lock);
|
||||
translate_keys(&mut keys, &vm_lock)?;
|
||||
drop(vm_lock);
|
||||
|
||||
debug!("setting up mpc-tls");
|
||||
@@ -140,8 +112,7 @@ impl Prover<state::Initialized> {
|
||||
Ok(Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::CommitAccepted {
|
||||
verifier_io: None,
|
||||
state: state::Setup {
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
mpc_tls,
|
||||
@@ -152,39 +123,37 @@ impl Prover<state::Initialized> {
|
||||
}
|
||||
}
|
||||
|
||||
impl Prover<state::CommitAccepted> {
|
||||
/// Sets up the prover with the client configuration.
|
||||
impl Prover<state::Setup> {
|
||||
/// Connects to the server using the provided socket.
|
||||
///
|
||||
/// Returns a set up prover, and a [`TlsConnection`] which can be used to
|
||||
/// read and write bytes from/to the server.
|
||||
/// Returns a handle to the TLS connection, a future which returns the
|
||||
/// prover once the connection is closed.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - The TLS client configuration.
|
||||
/// * `socket` - The socket to the server.
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
pub fn setup(
|
||||
pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
self,
|
||||
config: TlsClientConfig,
|
||||
) -> Result<(TlsConnection, Prover<state::Setup>), ProverError> {
|
||||
let state::CommitAccepted {
|
||||
verifier_io,
|
||||
socket: S,
|
||||
) -> Result<(TlsConnection, ProverFuture), ProverError> {
|
||||
let state::Setup {
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
mut mux_fut,
|
||||
mpc_tls,
|
||||
keys,
|
||||
vm,
|
||||
..
|
||||
} = self.state;
|
||||
|
||||
let decrypt = mpc_tls.is_decrypting();
|
||||
let (mpc_ctrl, mpc_fut) = mpc_tls.run();
|
||||
|
||||
let ServerName::Dns(server_name) = config.server_name();
|
||||
let ServerName::Dns(server_name) = self.config.server_name();
|
||||
let server_name =
|
||||
TlsServerName::try_from(server_name.as_ref()).expect("name was validated");
|
||||
|
||||
let root_store = tls_client::RootCertStore {
|
||||
roots: config
|
||||
.root_store()
|
||||
let root_store = if let Some(root_store) = self.config.tls_config().root_store() {
|
||||
let roots = root_store
|
||||
.roots
|
||||
.iter()
|
||||
.map(|cert| {
|
||||
@@ -193,15 +162,20 @@ impl Prover<state::CommitAccepted> {
|
||||
.map(|anchor| anchor.to_owned())
|
||||
.map_err(ProverError::config)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
tls_client::RootCertStore { roots }
|
||||
} else {
|
||||
tls_client::RootCertStore {
|
||||
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
|
||||
}
|
||||
};
|
||||
|
||||
let rustls_config = tls_client::ClientConfig::builder()
|
||||
let config = tls_client::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store);
|
||||
|
||||
let rustls_config = if let Some((cert, key)) = config.client_auth() {
|
||||
rustls_config
|
||||
let config = if let Some((cert, key)) = self.config.tls_config().client_auth() {
|
||||
config
|
||||
.with_single_cert(
|
||||
cert.iter()
|
||||
.map(|cert| tls_client::Certificate(cert.0.clone()))
|
||||
@@ -210,306 +184,101 @@ impl Prover<state::CommitAccepted> {
|
||||
)
|
||||
.map_err(ProverError::config)?
|
||||
} else {
|
||||
rustls_config.with_no_client_auth()
|
||||
config.with_no_client_auth()
|
||||
};
|
||||
|
||||
let client = ClientConnection::new(
|
||||
Arc::new(rustls_config),
|
||||
Box::new(mpc_ctrl.clone()),
|
||||
server_name,
|
||||
)
|
||||
.map_err(ProverError::config)?;
|
||||
let client =
|
||||
ClientConnection::new(Arc::new(config), Box::new(mpc_ctrl.clone()), server_name)
|
||||
.map_err(ProverError::config)?;
|
||||
|
||||
let span = self.span.clone();
|
||||
let (conn, conn_fut) = bind_client(socket, client);
|
||||
|
||||
let mpc_tls = MpcTlsClient::new(
|
||||
Box::new(mpc_fut.map_err(ProverError::from)),
|
||||
keys,
|
||||
vm,
|
||||
span,
|
||||
mpc_ctrl,
|
||||
client,
|
||||
decrypt,
|
||||
);
|
||||
let fut = Box::pin({
|
||||
let span = self.span.clone();
|
||||
let mpc_ctrl = mpc_ctrl.clone();
|
||||
async move {
|
||||
let conn_fut = async {
|
||||
mux_fut
|
||||
.poll_with(conn_fut.map_err(ProverError::from))
|
||||
.await?;
|
||||
|
||||
let (duplex_a, duplex_b) = futures_plex::duplex(BUF_CAP);
|
||||
let prover = Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Setup {
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
server_name: config.server_name().clone(),
|
||||
tls_client: Box::new(mpc_tls),
|
||||
client_io: duplex_a,
|
||||
verifier_io,
|
||||
},
|
||||
};
|
||||
mpc_ctrl.stop().await?;
|
||||
|
||||
let conn = TlsConnection::new(duplex_b);
|
||||
Ok((conn, prover))
|
||||
}
|
||||
}
|
||||
Ok::<_, ProverError>(())
|
||||
};
|
||||
|
||||
impl Prover<state::Setup> {
|
||||
/// Returns a handle to control the prover.
|
||||
pub fn handle(&self) -> ProverControl {
|
||||
let handle = self.state.tls_client.handle();
|
||||
ProverControl { handle }
|
||||
}
|
||||
info!("starting MPC-TLS");
|
||||
|
||||
/// Attaches IO to the prover.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `server_io` - The IO to the server.
|
||||
/// * `verifier_io` - The IO to the TLS verifier.
|
||||
pub fn connect<S, T>(self, server_io: S, verifier_io: T) -> Prover<state::Connected<S, T>>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
{
|
||||
let (client_to_server, server_to_client) = futures_plex::duplex(BUF_CAP);
|
||||
let (_, (mut ctx, tls_transcript)) = futures::try_join!(
|
||||
conn_fut,
|
||||
mpc_fut.in_current_span().map_err(ProverError::from)
|
||||
)?;
|
||||
|
||||
Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Connected {
|
||||
verifier_io: self.state.verifier_io,
|
||||
mux_ctrl: self.state.mux_ctrl,
|
||||
mux_fut: self.state.mux_fut,
|
||||
server_name: self.state.server_name,
|
||||
tls_client: self.state.tls_client,
|
||||
client_io: self.state.client_io,
|
||||
output: None,
|
||||
server_socket: server_io,
|
||||
verifier_socket: verifier_io,
|
||||
tls_client_to_server_buf: client_to_server,
|
||||
server_to_tls_client_buf: server_to_client,
|
||||
client_closed: false,
|
||||
server_closed: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
info!("finished MPC-TLS");
|
||||
|
||||
/// This is a convenience method which attaches IO, runs the prover and
|
||||
/// returns a committed prover together with the IO.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `server_io` - The IO to the server.
|
||||
/// * `verifier_io` - The IO to the TLS verifier.
|
||||
pub async fn run<S, T>(
|
||||
self,
|
||||
mut server_io: S,
|
||||
mut verifier_io: T,
|
||||
) -> Result<(Prover<state::Committed>, S, T), ProverError>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
{
|
||||
let mut prover = self.connect(&mut server_io, &mut verifier_io);
|
||||
(&mut prover).await?;
|
||||
{
|
||||
let mut vm = vm.try_lock().expect("VM should not be locked");
|
||||
|
||||
let prover = prover.finish()?;
|
||||
Ok((prover, server_io, verifier_io))
|
||||
}
|
||||
}
|
||||
debug!("finalizing mpc");
|
||||
|
||||
impl<S, T> Future for Prover<state::Connected<S, T>>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
{
|
||||
type Output = Result<(), ProverError>;
|
||||
// Finalize DEAP.
|
||||
mux_fut
|
||||
.poll_with(vm.finalize(&mut ctx))
|
||||
.await
|
||||
.map_err(ProverError::mpc)?;
|
||||
|
||||
fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut state = Pin::new(&mut self.state).project();
|
||||
|
||||
loop {
|
||||
let mut progress = false;
|
||||
|
||||
if state.output.is_none()
|
||||
&& let Poll::Ready(output) = state.tls_client.poll(cx)?
|
||||
{
|
||||
*state.output = Some(output);
|
||||
}
|
||||
|
||||
progress |= Self::io_client_conn(&mut state, cx)?;
|
||||
progress |= Self::io_client_server(&mut state, cx)?;
|
||||
progress |= Self::io_client_verifier(&mut state, cx)?;
|
||||
|
||||
_ = state.mux_fut.poll_unpin(cx)?;
|
||||
|
||||
if *state.server_closed && state.output.is_some() {
|
||||
ready!(state.client_io.poll_close(cx))?;
|
||||
ready!(state.server_socket.poll_close(cx))?;
|
||||
|
||||
return Poll::Ready(Ok(()));
|
||||
} else if !progress {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T> Prover<state::Connected<S, T>>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
{
|
||||
fn io_client_conn(
|
||||
state: &mut ConnectedProj<S, T>,
|
||||
cx: &mut Context,
|
||||
) -> Result<bool, ProverError> {
|
||||
let mut progress = false;
|
||||
|
||||
// tls_conn -> tls_client
|
||||
if state.tls_client.wants_write()
|
||||
&& let Poll::Ready(mut simplex) = state.client_io.as_mut().poll_lock_read(cx)
|
||||
&& let Poll::Ready(buf) = simplex.poll_get(cx)?
|
||||
{
|
||||
if !buf.is_empty() {
|
||||
let write = state.tls_client.write(buf)?;
|
||||
if write > 0 {
|
||||
progress = true;
|
||||
simplex.advance(write);
|
||||
debug!("mpc finalized");
|
||||
}
|
||||
} else if !*state.client_closed && !*state.server_closed {
|
||||
progress = true;
|
||||
*state.client_closed = true;
|
||||
state.tls_client.client_close()?;
|
||||
|
||||
// Pull out ZK VM.
|
||||
let (_, mut vm) = Arc::into_inner(vm)
|
||||
.expect("vm should have only 1 reference")
|
||||
.into_inner()
|
||||
.into_inner();
|
||||
|
||||
// Prove tag verification of received records.
|
||||
// The prover drops the proof output.
|
||||
let _ = verify_tags(
|
||||
&mut vm,
|
||||
(keys.server_write_key, keys.server_write_iv),
|
||||
keys.server_write_mac_key,
|
||||
*tls_transcript.version(),
|
||||
tls_transcript.recv().to_vec(),
|
||||
)
|
||||
.map_err(ProverError::zk)?;
|
||||
|
||||
mux_fut
|
||||
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
|
||||
.await?;
|
||||
|
||||
let transcript = tls_transcript
|
||||
.to_transcript()
|
||||
.expect("transcript is complete");
|
||||
|
||||
Ok(Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Committed {
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
ctx,
|
||||
vm,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
.instrument(span)
|
||||
});
|
||||
|
||||
// tls_client -> tls_conn
|
||||
if state.tls_client.wants_read()
|
||||
&& let Poll::Ready(mut simplex) = state.client_io.as_mut().poll_lock_write(cx)
|
||||
&& let Poll::Ready(buf) = simplex.poll_mut(cx)?
|
||||
&& let read = state.tls_client.read(buf)?
|
||||
&& read > 0
|
||||
{
|
||||
progress = true;
|
||||
simplex.advance_mut(read);
|
||||
}
|
||||
Ok(progress)
|
||||
}
|
||||
|
||||
fn io_client_server(
|
||||
state: &mut ConnectedProj<S, T>,
|
||||
cx: &mut Context,
|
||||
) -> Result<bool, ProverError> {
|
||||
let mut progress = false;
|
||||
|
||||
// server_socket -> buf
|
||||
if let Poll::Ready(write) = state
|
||||
.server_to_tls_client_buf
|
||||
.poll_write_from(cx, state.server_socket.as_mut())?
|
||||
{
|
||||
if write > 0 {
|
||||
progress = true;
|
||||
} else if !*state.server_closed {
|
||||
progress = true;
|
||||
*state.server_closed = true;
|
||||
state.tls_client.server_close()?;
|
||||
}
|
||||
}
|
||||
|
||||
// buf -> tls_client
|
||||
if state.tls_client.wants_read_tls()
|
||||
&& let Poll::Ready(mut simplex) =
|
||||
state.tls_client_to_server_buf.as_mut().poll_lock_read(cx)
|
||||
&& let Poll::Ready(buf) = simplex.poll_get(cx)?
|
||||
&& let read = state.tls_client.read_tls(buf)?
|
||||
&& read > 0
|
||||
{
|
||||
progress = true;
|
||||
simplex.advance(read);
|
||||
}
|
||||
|
||||
// tls_client -> buf
|
||||
if state.tls_client.wants_write_tls()
|
||||
&& let Poll::Ready(mut simplex) =
|
||||
state.tls_client_to_server_buf.as_mut().poll_lock_write(cx)
|
||||
&& let Poll::Ready(buf) = simplex.poll_mut(cx)?
|
||||
&& let write = state.tls_client.write_tls(buf)?
|
||||
&& write > 0
|
||||
{
|
||||
progress = true;
|
||||
simplex.advance_mut(write);
|
||||
}
|
||||
|
||||
// buf -> server_socket
|
||||
if let Poll::Ready(read) = state
|
||||
.server_to_tls_client_buf
|
||||
.poll_read_to(cx, state.server_socket.as_mut())?
|
||||
&& read > 0
|
||||
{
|
||||
progress = true;
|
||||
}
|
||||
|
||||
Ok(progress)
|
||||
}
|
||||
|
||||
fn io_client_verifier(
|
||||
state: &mut ConnectedProj<S, T>,
|
||||
cx: &mut Context,
|
||||
) -> Result<bool, ProverError> {
|
||||
let mut progress = false;
|
||||
|
||||
let verifier_io = Pin::new(
|
||||
(*state.verifier_io)
|
||||
.as_mut()
|
||||
.expect("verifier io should be available"),
|
||||
);
|
||||
|
||||
// mux -> verifier_socket
|
||||
if let Poll::Ready(read) = verifier_io.poll_read_to(cx, state.verifier_socket.as_mut())?
|
||||
&& read > 0
|
||||
{
|
||||
progress = true;
|
||||
}
|
||||
|
||||
// verifier_socket -> mux
|
||||
if let Poll::Ready(write) =
|
||||
verifier_io.poll_write_from(cx, state.verifier_socket.as_mut())?
|
||||
&& write > 0
|
||||
{
|
||||
progress = true;
|
||||
}
|
||||
|
||||
Ok(progress)
|
||||
}
|
||||
|
||||
/// Returns a committed prover after the TLS session has completed.
|
||||
pub fn finish(self) -> Result<Prover<state::Committed>, ProverError> {
|
||||
let TlsOutput {
|
||||
ctx,
|
||||
vm,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
} = self.state.output.ok_or(ProverError::state(
|
||||
"prover has not yet closed the connection",
|
||||
))?;
|
||||
|
||||
let prover = Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Committed {
|
||||
verifier_io: self.state.verifier_io,
|
||||
mux_ctrl: self.state.mux_ctrl,
|
||||
mux_fut: self.state.mux_fut,
|
||||
ctx,
|
||||
vm,
|
||||
server_name: self.state.server_name,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
Ok((
|
||||
conn,
|
||||
ProverFuture {
|
||||
fut,
|
||||
ctrl: ProverControl { mpc_ctrl },
|
||||
},
|
||||
};
|
||||
|
||||
Ok(prover)
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -529,107 +298,140 @@ impl Prover<state::Committed> {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - The disclosure configuration.
|
||||
/// * `verifier_io` - The IO to the TLS verifier.
|
||||
pub async fn prove<S>(
|
||||
&mut self,
|
||||
config: &ProveConfig,
|
||||
verifier_io: S,
|
||||
) -> Result<ProverOutput, ProverError>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
{
|
||||
let mut duplex = self
|
||||
.state
|
||||
.verifier_io
|
||||
.take()
|
||||
.expect("duplex should be available");
|
||||
|
||||
let fut = Box::pin(self.prove_inner(config).fuse());
|
||||
let output = await_with_copy_io(fut, verifier_io, &mut duplex).await?;
|
||||
|
||||
self.state.verifier_io = Some(duplex);
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "info", skip_all, err)]
|
||||
async fn prove_inner(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> {
|
||||
pub async fn prove(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> {
|
||||
let state::Committed {
|
||||
mux_fut,
|
||||
ctx,
|
||||
vm,
|
||||
keys,
|
||||
server_name,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
..
|
||||
} = &mut self.state;
|
||||
|
||||
let handshake = config.server_identity().then(|| {
|
||||
(
|
||||
server_name.clone(),
|
||||
HandshakeData {
|
||||
certs: tls_transcript
|
||||
.server_cert_chain()
|
||||
.expect("server cert chain is present")
|
||||
.to_vec(),
|
||||
sig: tls_transcript
|
||||
.server_signature()
|
||||
.expect("server signature is present")
|
||||
.clone(),
|
||||
binding: tls_transcript.certificate_binding().clone(),
|
||||
},
|
||||
)
|
||||
});
|
||||
|
||||
let partial_transcript = config
|
||||
.reveal()
|
||||
.map(|(sent, recv)| transcript.to_partial(sent.clone(), recv.clone()));
|
||||
|
||||
let msg = ProveRequestMsg {
|
||||
request: config.to_request(),
|
||||
handshake,
|
||||
transcript: partial_transcript,
|
||||
};
|
||||
|
||||
let output = mux_fut
|
||||
.poll_with(async {
|
||||
ctx.io_mut().send(msg).await.map_err(ProverError::from)?;
|
||||
|
||||
ctx.io_mut().expect_next::<Response>().await?.result?;
|
||||
|
||||
prove::prove(ctx, vm, keys, transcript, tls_transcript, config).await
|
||||
})
|
||||
.poll_with(prove::prove(
|
||||
ctx,
|
||||
vm,
|
||||
keys,
|
||||
self.config.server_name(),
|
||||
transcript,
|
||||
tls_transcript,
|
||||
config,
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Closes the connection with the verifier.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `verifier_io` - The IO to the TLS verifier.
|
||||
#[instrument(parent = &self.span, level = "info", skip_all, err)]
|
||||
pub async fn close<S>(mut self, mut verifier_io: S) -> Result<(), ProverError>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
{
|
||||
pub async fn close(self) -> Result<(), ProverError> {
|
||||
let state::Committed {
|
||||
mux_ctrl, mux_fut, ..
|
||||
} = self.state;
|
||||
|
||||
let mut duplex = self
|
||||
.state
|
||||
.verifier_io
|
||||
.take()
|
||||
.expect("duplex should be available");
|
||||
// Wait for the verifier to correctly close the connection.
|
||||
if !mux_fut.is_complete() {
|
||||
mux_ctrl.close();
|
||||
mux_fut.await?;
|
||||
}
|
||||
|
||||
mux_ctrl.close();
|
||||
let copy = CopyIo::new(&mut verifier_io, &mut duplex).map_err(ProverError::from);
|
||||
futures::try_join!(mux_fut.map_err(ProverError::from), copy)?;
|
||||
|
||||
// Wait for the verifier to finish closing.
|
||||
verifier_io.read_exact(&mut [0_u8; 5]).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn build_mpc_tls(config: &ProverConfig, ctx: Context) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsLeader) {
|
||||
let mut rng = rand::rng();
|
||||
let delta = Delta::new(Block::random(&mut rng));
|
||||
|
||||
let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
|
||||
let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
|
||||
let rcot_send = mpz_ot::kos::Sender::new(
|
||||
mpz_ot::kos::SenderConfig::default(),
|
||||
delta.into_inner(),
|
||||
base_ot_recv,
|
||||
);
|
||||
let rcot_recv =
|
||||
mpz_ot::kos::Receiver::new(mpz_ot::kos::ReceiverConfig::default(), base_ot_send);
|
||||
let rcot_recv = mpz_ot::ferret::Receiver::new(
|
||||
mpz_ot::ferret::FerretConfig::builder()
|
||||
.lpn_type(mpz_ot::ferret::LpnType::Regular)
|
||||
.build()
|
||||
.expect("ferret config is valid"),
|
||||
Block::random(&mut rng),
|
||||
rcot_recv,
|
||||
);
|
||||
|
||||
let rcot_send = mpz_ot::rcot::shared::SharedRCOTSender::new(rcot_send);
|
||||
let rcot_recv = mpz_ot::rcot::shared::SharedRCOTReceiver::new(rcot_recv);
|
||||
|
||||
let mpc = Mpc::new(
|
||||
mpz_ot::cot::DerandCOTSender::new(rcot_send.clone()),
|
||||
rng.random(),
|
||||
delta,
|
||||
);
|
||||
|
||||
let zk = Zk::new(ZkProverConfig::default(), rcot_recv.clone());
|
||||
|
||||
let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Leader, mpc, zk)));
|
||||
|
||||
(
|
||||
vm.clone(),
|
||||
MpcTlsLeader::new(
|
||||
config.build_mpc_tls_config(),
|
||||
ctx,
|
||||
vm,
|
||||
(rcot_send.clone(), rcot_send.clone(), rcot_send),
|
||||
rcot_recv,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
/// A controller for the prover.
|
||||
#[derive(Clone)]
|
||||
pub struct ProverControl {
|
||||
mpc_ctrl: LeaderCtrl,
|
||||
}
|
||||
|
||||
impl ProverControl {
|
||||
/// Defers decryption of data from the server until the server has closed
|
||||
/// the connection.
|
||||
///
|
||||
/// This is a performance optimization which will significantly reduce the
|
||||
/// amount of upload bandwidth used by the prover.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// * The prover may need to close the connection to the server in order for
|
||||
/// it to close the connection on its end. If neither the prover or server
|
||||
/// close the connection this will cause a deadlock.
|
||||
pub async fn defer_decryption(&self) -> Result<(), ProverError> {
|
||||
self.mpc_ctrl
|
||||
.defer_decryption()
|
||||
.await
|
||||
.map_err(ProverError::from)
|
||||
}
|
||||
}
|
||||
|
||||
/// Translates VM references to the ZK address space.
|
||||
fn translate_keys<Mpc, Zk>(keys: &mut SessionKeys, vm: &Deap<Mpc, Zk>) -> Result<(), ProverError> {
|
||||
keys.client_write_key = vm
|
||||
.translate(keys.client_write_key)
|
||||
.map_err(ProverError::mpc)?;
|
||||
keys.client_write_iv = vm
|
||||
.translate(keys.client_write_iv)
|
||||
.map_err(ProverError::mpc)?;
|
||||
keys.server_write_key = vm
|
||||
.translate(keys.server_write_key)
|
||||
.map_err(ProverError::mpc)?;
|
||||
keys.server_write_iv = vm
|
||||
.translate(keys.server_write_iv)
|
||||
.map_err(ProverError::mpc)?;
|
||||
keys.server_write_mac_key = vm
|
||||
.translate(keys.server_write_mac_key)
|
||||
.map_err(ProverError::mpc)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
//! Provides a TLS client.
|
||||
|
||||
use crate::{mpz::ProverZk, prover::control::ControlError};
|
||||
use mpc_tls::SessionKeys;
|
||||
use std::{
|
||||
sync::mpsc::{Sender, SyncSender, sync_channel},
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tlsn_core::transcript::{TlsTranscript, Transcript};
|
||||
|
||||
mod mpc;
|
||||
|
||||
pub(crate) use mpc::MpcTlsClient;
|
||||
|
||||
/// TLS client for MPC and proxy-based TLS implementations.
|
||||
pub(crate) trait TlsClient {
|
||||
type Error: std::error::Error + Send + Sync + Unpin + 'static;
|
||||
|
||||
/// Returns `true` if the client wants to read TLS data from the server.
|
||||
fn wants_read_tls(&self) -> bool;
|
||||
|
||||
/// Returns `true` if the client wants to write TLS data to the server.
|
||||
fn wants_write_tls(&self) -> bool;
|
||||
|
||||
/// Reads TLS data from the server.
|
||||
fn read_tls(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Writes TLS data for the server into the provided buffer.
|
||||
fn write_tls(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Returns `true` if the client wants to read plaintext data.
|
||||
fn wants_read(&self) -> bool;
|
||||
|
||||
/// Returns `true` if the client wants to write plaintext data.
|
||||
fn wants_write(&self) -> bool;
|
||||
|
||||
/// Reads plaintext data from the server into the provided buffer.
|
||||
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Writes plaintext data to be sent to the server.
|
||||
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Client closes the connection.
|
||||
fn client_close(&mut self) -> Result<(), Self::Error>;
|
||||
|
||||
/// Server closes the connection.
|
||||
fn server_close(&mut self) -> Result<(), Self::Error>;
|
||||
|
||||
/// Returns a handle to control the client.
|
||||
fn handle(&self) -> ClientHandle;
|
||||
|
||||
/// Polls the client to make progress.
|
||||
fn poll(&mut self, cx: &mut Context) -> Poll<Result<TlsOutput, Self::Error>>;
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct ClientHandle {
|
||||
sender: Sender<Command>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) enum Command {
|
||||
IsDecrypting(SyncSender<bool>),
|
||||
SetDecrypt(bool),
|
||||
ClientClose,
|
||||
ServerClose,
|
||||
}
|
||||
|
||||
impl ClientHandle {
|
||||
pub(crate) fn enable_decryption(&self, enable: bool) -> Result<(), ControlError> {
|
||||
self.sender
|
||||
.send(Command::SetDecrypt(enable))
|
||||
.map_err(|_| ControlError)
|
||||
}
|
||||
|
||||
pub(crate) fn is_decrypting(&self) -> bool {
|
||||
let (sender, receiver) = sync_channel(1);
|
||||
let Ok(_) = self.sender.send(Command::IsDecrypting(sender)) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
receiver.recv().unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Output of a TLS session.
|
||||
pub(crate) struct TlsOutput {
|
||||
pub(crate) ctx: mpz_common::Context,
|
||||
pub(crate) vm: ProverZk,
|
||||
pub(crate) keys: SessionKeys,
|
||||
pub(crate) tls_transcript: TlsTranscript,
|
||||
pub(crate) transcript: Transcript,
|
||||
}
|
||||
@@ -1,503 +0,0 @@
|
||||
//! Implementation of an MPC-TLS client.
|
||||
|
||||
use crate::{
|
||||
mpz::{ProverMpc, ProverZk},
|
||||
prover::{
|
||||
ProverError,
|
||||
client::{ClientHandle, Command, TlsClient, TlsOutput},
|
||||
},
|
||||
tag::verify_tags,
|
||||
};
|
||||
use futures::{Future, FutureExt};
|
||||
use mpc_tls::{LeaderCtrl, SessionKeys};
|
||||
use mpz_common::Context;
|
||||
use mpz_vm_core::Execute;
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::{
|
||||
Arc,
|
||||
mpsc::{Receiver, Sender, channel},
|
||||
},
|
||||
task::Poll,
|
||||
};
|
||||
use tls_client::ClientConnection;
|
||||
use tlsn_core::transcript::TlsTranscript;
|
||||
use tlsn_deap::Deap;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{Span, debug, instrument, trace, warn};
|
||||
|
||||
pub(crate) type MpcFuture =
|
||||
Box<dyn Future<Output = Result<(Context, TlsTranscript), ProverError>> + Send>;
|
||||
|
||||
type FinalizeFuture =
|
||||
Box<dyn Future<Output = Result<(InnerState, Context, TlsTranscript), ProverError>> + Send>;
|
||||
|
||||
pub(crate) struct MpcTlsClient {
|
||||
sender: Sender<Command>,
|
||||
state: State,
|
||||
decrypt: bool,
|
||||
}
|
||||
|
||||
enum State {
|
||||
Start {
|
||||
mpc: Pin<MpcFuture>,
|
||||
inner: Box<InnerState>,
|
||||
receiver: Receiver<Command>,
|
||||
},
|
||||
Active {
|
||||
mpc: Pin<MpcFuture>,
|
||||
inner: Box<InnerState>,
|
||||
receiver: Receiver<Command>,
|
||||
},
|
||||
Busy {
|
||||
mpc: Pin<MpcFuture>,
|
||||
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
|
||||
receiver: Receiver<Command>,
|
||||
},
|
||||
MpcStop {
|
||||
mpc: Pin<MpcFuture>,
|
||||
inner: Box<InnerState>,
|
||||
},
|
||||
CloseBusy {
|
||||
mpc: Pin<MpcFuture>,
|
||||
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
|
||||
},
|
||||
Finishing {
|
||||
ctx: Context,
|
||||
transcript: Box<TlsTranscript>,
|
||||
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
|
||||
},
|
||||
Finalizing {
|
||||
fut: Pin<FinalizeFuture>,
|
||||
},
|
||||
Finished,
|
||||
Error,
|
||||
}
|
||||
|
||||
impl MpcTlsClient {
|
||||
pub(crate) fn new(
|
||||
mpc: MpcFuture,
|
||||
keys: SessionKeys,
|
||||
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
|
||||
span: Span,
|
||||
mpc_ctrl: LeaderCtrl,
|
||||
tls: ClientConnection,
|
||||
decrypt: bool,
|
||||
) -> Self {
|
||||
let inner = InnerState {
|
||||
span,
|
||||
tls,
|
||||
vm,
|
||||
keys,
|
||||
mpc_ctrl,
|
||||
client_closed: false,
|
||||
mpc_stopped: false,
|
||||
};
|
||||
let (sender, receiver) = channel();
|
||||
|
||||
Self {
|
||||
sender,
|
||||
decrypt,
|
||||
state: State::Start {
|
||||
receiver,
|
||||
mpc: Box::into_pin(mpc),
|
||||
inner: Box::new(inner),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn inner_client_mut(&mut self) -> Option<&mut ClientConnection> {
|
||||
if let State::Active { inner, .. } | State::MpcStop { inner, .. } = &mut self.state {
|
||||
Some(&mut inner.tls)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn inner_client(&self) -> Option<&ClientConnection> {
|
||||
if let State::Active { inner, .. } | State::MpcStop { inner, .. } = &self.state {
|
||||
Some(&inner.tls)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TlsClient for MpcTlsClient {
|
||||
type Error = ProverError;
|
||||
|
||||
fn wants_read_tls(&self) -> bool {
|
||||
if let Some(client) = self.inner_client() {
|
||||
client.wants_read()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn wants_write_tls(&self) -> bool {
|
||||
if let Some(client) = self.inner_client() {
|
||||
client.wants_write()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn read_tls(&mut self, mut buf: &[u8]) -> Result<usize, Self::Error> {
|
||||
if let Some(client) = self.inner_client_mut()
|
||||
&& client.wants_read()
|
||||
{
|
||||
client.read_tls(&mut buf).map_err(ProverError::from)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn write_tls(&mut self, mut buf: &mut [u8]) -> Result<usize, Self::Error> {
|
||||
if let Some(client) = self.inner_client_mut()
|
||||
&& client.wants_write()
|
||||
{
|
||||
client.write_tls(&mut buf).map_err(ProverError::from)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn wants_read(&self) -> bool {
|
||||
if let Some(client) = self.inner_client() {
|
||||
!client.plaintext_is_empty()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn wants_write(&self) -> bool {
|
||||
if let Some(client) = self.inner_client() {
|
||||
!client.sendable_plaintext_is_full()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
|
||||
if let Some(client) = self.inner_client_mut()
|
||||
&& !client.plaintext_is_empty()
|
||||
{
|
||||
client.read_plaintext(buf).map_err(ProverError::from)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
|
||||
if let Some(client) = self.inner_client_mut()
|
||||
&& !client.sendable_plaintext_is_full()
|
||||
{
|
||||
client.write_plaintext(buf).map_err(ProverError::from)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn client_close(&mut self) -> Result<(), Self::Error> {
|
||||
self.sender
|
||||
.send(Command::ClientClose)
|
||||
.map_err(|_| ProverError::state("unable to close connection clientside"))
|
||||
}
|
||||
|
||||
fn server_close(&mut self) -> Result<(), Self::Error> {
|
||||
self.sender
|
||||
.send(Command::ServerClose)
|
||||
.map_err(|_| ProverError::state("unable to close connection serverside"))
|
||||
}
|
||||
|
||||
fn handle(&self) -> ClientHandle {
|
||||
ClientHandle {
|
||||
sender: self.sender.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll(&mut self, cx: &mut std::task::Context) -> Poll<Result<TlsOutput, Self::Error>> {
|
||||
match std::mem::replace(&mut self.state, State::Error) {
|
||||
State::Start {
|
||||
mpc,
|
||||
inner,
|
||||
receiver,
|
||||
} => {
|
||||
trace!("inner client is starting");
|
||||
self.state = State::Busy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.start()),
|
||||
receiver,
|
||||
};
|
||||
self.poll(cx)
|
||||
}
|
||||
State::Active {
|
||||
mpc,
|
||||
inner,
|
||||
receiver,
|
||||
} => {
|
||||
trace!("inner client is active");
|
||||
|
||||
if !inner.tls.is_handshaking()
|
||||
&& let Ok(cmd) = receiver.try_recv()
|
||||
{
|
||||
match cmd {
|
||||
Command::ClientClose => {
|
||||
self.state = State::Busy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.client_close()),
|
||||
receiver,
|
||||
};
|
||||
}
|
||||
Command::ServerClose => {
|
||||
std::mem::drop(receiver);
|
||||
self.state = State::CloseBusy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.server_close()),
|
||||
};
|
||||
}
|
||||
Command::SetDecrypt(enable) => {
|
||||
self.decrypt = enable;
|
||||
self.state = State::Busy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.set_decrypt(enable)),
|
||||
receiver,
|
||||
};
|
||||
}
|
||||
Command::IsDecrypting(sender) => {
|
||||
_ = sender.send(self.decrypt);
|
||||
self.state = State::Busy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.run()),
|
||||
receiver,
|
||||
};
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.state = State::Busy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.run()),
|
||||
receiver,
|
||||
};
|
||||
}
|
||||
self.poll(cx)
|
||||
}
|
||||
State::Busy {
|
||||
mut mpc,
|
||||
mut fut,
|
||||
receiver,
|
||||
} => {
|
||||
trace!("inner client is busy");
|
||||
|
||||
let mpc_poll = mpc.as_mut().poll(cx)?;
|
||||
|
||||
assert!(
|
||||
matches!(mpc_poll, Poll::Pending),
|
||||
"mpc future should not be finished here"
|
||||
);
|
||||
|
||||
match fut.as_mut().poll(cx)? {
|
||||
Poll::Ready(inner) => {
|
||||
self.state = State::Active {
|
||||
mpc,
|
||||
inner,
|
||||
receiver,
|
||||
};
|
||||
}
|
||||
Poll::Pending => self.state = State::Busy { mpc, fut, receiver },
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
State::MpcStop { mpc, inner } => {
|
||||
trace!("inner client is stopping mpc");
|
||||
self.state = State::CloseBusy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.stop()),
|
||||
};
|
||||
self.poll(cx)
|
||||
}
|
||||
State::CloseBusy { mut mpc, mut fut } => {
|
||||
trace!("inner client is busy closing");
|
||||
match (fut.poll_unpin(cx)?, mpc.poll_unpin(cx)?) {
|
||||
(Poll::Ready(inner), Poll::Ready((ctx, transcript))) => {
|
||||
self.state = State::Finalizing {
|
||||
fut: Box::pin(inner.finalize(ctx, transcript)),
|
||||
};
|
||||
self.poll(cx)
|
||||
}
|
||||
(Poll::Ready(inner), Poll::Pending) => {
|
||||
self.state = State::MpcStop { mpc, inner };
|
||||
Poll::Pending
|
||||
}
|
||||
(Poll::Pending, Poll::Ready((ctx, transcript))) => {
|
||||
self.state = State::Finishing {
|
||||
ctx,
|
||||
transcript: Box::new(transcript),
|
||||
fut,
|
||||
};
|
||||
Poll::Pending
|
||||
}
|
||||
(Poll::Pending, Poll::Pending) => {
|
||||
self.state = State::CloseBusy { mpc, fut };
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
State::Finishing {
|
||||
ctx,
|
||||
transcript,
|
||||
mut fut,
|
||||
} => {
|
||||
trace!("inner client is finishing");
|
||||
if let Poll::Ready(inner) = fut.poll_unpin(cx)? {
|
||||
self.state = State::Finalizing {
|
||||
fut: Box::pin(inner.finalize(ctx, *transcript)),
|
||||
};
|
||||
self.poll(cx)
|
||||
} else {
|
||||
self.state = State::Finishing {
|
||||
ctx,
|
||||
transcript,
|
||||
fut,
|
||||
};
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
State::Finalizing { mut fut } => match fut.poll_unpin(cx) {
|
||||
Poll::Ready(output) => {
|
||||
let (inner, ctx, tls_transcript) = output?;
|
||||
let InnerState { vm, keys, .. } = inner;
|
||||
|
||||
let transcript = tls_transcript
|
||||
.to_transcript()
|
||||
.expect("transcript is complete");
|
||||
|
||||
let (_, vm) = Arc::into_inner(vm)
|
||||
.expect("vm should have only 1 reference")
|
||||
.into_inner()
|
||||
.into_inner();
|
||||
|
||||
let output = TlsOutput {
|
||||
ctx,
|
||||
vm,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
};
|
||||
|
||||
self.state = State::Finished;
|
||||
Poll::Ready(Ok(output))
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.state = State::Finalizing { fut };
|
||||
Poll::Pending
|
||||
}
|
||||
},
|
||||
State::Finished => Poll::Ready(Err(ProverError::state(
|
||||
"mpc tls client polled again in finished state",
|
||||
))),
|
||||
State::Error => {
|
||||
Poll::Ready(Err(ProverError::state("mpc tls client is in error state")))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct InnerState {
|
||||
span: Span,
|
||||
tls: ClientConnection,
|
||||
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
|
||||
keys: SessionKeys,
|
||||
mpc_ctrl: LeaderCtrl,
|
||||
client_closed: bool,
|
||||
mpc_stopped: bool,
|
||||
}
|
||||
|
||||
impl InnerState {
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn start(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
self.tls.start().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "trace", skip_all, err)]
|
||||
async fn run(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
self.tls.process_new_packets().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn set_decrypt(self: Box<Self>, enable: bool) -> Result<Box<Self>, ProverError> {
|
||||
self.mpc_ctrl.enable_decryption(enable).await?;
|
||||
self.run().await
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn client_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
if !self.client_closed {
|
||||
debug!("sending close notify");
|
||||
if let Err(e) = self.tls.send_close_notify().await {
|
||||
warn!("failed to send close_notify to server: {}", e);
|
||||
}
|
||||
self.client_closed = true;
|
||||
}
|
||||
self.run().await
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn server_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
self.tls.process_new_packets().await?;
|
||||
self.tls.server_closed().await?;
|
||||
debug!("closed connection serverside");
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn stop(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
self.tls.process_new_packets().await?;
|
||||
if !self.mpc_stopped && self.tls.plaintext_is_empty() && self.tls.is_empty().await? {
|
||||
self.mpc_ctrl.stop().await?;
|
||||
self.mpc_stopped = true;
|
||||
debug!("stopped mpc");
|
||||
}
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn finalize(
|
||||
self,
|
||||
mut ctx: Context,
|
||||
transcript: TlsTranscript,
|
||||
) -> Result<(Self, Context, TlsTranscript), ProverError> {
|
||||
{
|
||||
let mut vm = self.vm.try_lock().expect("VM should not be locked");
|
||||
|
||||
// Finalize DEAP.
|
||||
vm.finalize(&mut ctx).await.map_err(ProverError::mpc)?;
|
||||
|
||||
debug!("mpc finalized");
|
||||
|
||||
// Pull out ZK VM.
|
||||
let mut zk = vm.zk();
|
||||
|
||||
// Prove tag verification of received records.
|
||||
// The prover drops the proof output.
|
||||
let _ = verify_tags(
|
||||
&mut *zk,
|
||||
(self.keys.server_write_key, self.keys.server_write_iv),
|
||||
self.keys.server_write_mac_key,
|
||||
*transcript.version(),
|
||||
transcript.recv().to_vec(),
|
||||
)
|
||||
.map_err(ProverError::zk)?;
|
||||
debug!("verified tags from server");
|
||||
|
||||
zk.execute_all(&mut ctx).await.map_err(ProverError::zk)?
|
||||
}
|
||||
|
||||
debug!("MPC-TLS done");
|
||||
Ok((self, ctx, transcript))
|
||||
}
|
||||
}
|
||||
144
crates/tlsn/src/prover/config.rs
Normal file
144
crates/tlsn/src/prover/config.rs
Normal file
@@ -0,0 +1,144 @@
|
||||
use mpc_tls::Config;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tlsn_core::{
|
||||
connection::ServerName,
|
||||
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
|
||||
};
|
||||
|
||||
use crate::config::{NetworkSetting, ProtocolConfig};
|
||||
|
||||
/// Configuration for the prover.
|
||||
#[derive(Debug, Clone, derive_builder::Builder, Serialize, Deserialize)]
|
||||
pub struct ProverConfig {
|
||||
/// The server DNS name.
|
||||
#[builder(setter(into))]
|
||||
server_name: ServerName,
|
||||
/// Protocol configuration to be checked with the verifier.
|
||||
protocol_config: ProtocolConfig,
|
||||
/// TLS configuration.
|
||||
#[builder(default)]
|
||||
tls_config: TlsConfig,
|
||||
}
|
||||
|
||||
impl ProverConfig {
|
||||
/// Creates a new builder for `ProverConfig`.
|
||||
pub fn builder() -> ProverConfigBuilder {
|
||||
ProverConfigBuilder::default()
|
||||
}
|
||||
|
||||
/// Returns the server DNS name.
|
||||
pub fn server_name(&self) -> &ServerName {
|
||||
&self.server_name
|
||||
}
|
||||
|
||||
/// Returns the protocol configuration.
|
||||
pub fn protocol_config(&self) -> &ProtocolConfig {
|
||||
&self.protocol_config
|
||||
}
|
||||
|
||||
/// Returns the TLS configuration.
|
||||
pub fn tls_config(&self) -> &TlsConfig {
|
||||
&self.tls_config
|
||||
}
|
||||
|
||||
pub(crate) fn build_mpc_tls_config(&self) -> Config {
|
||||
let mut builder = Config::builder();
|
||||
|
||||
builder
|
||||
.defer_decryption(self.protocol_config.defer_decryption_from_start())
|
||||
.max_sent(self.protocol_config.max_sent_data())
|
||||
.max_recv_online(self.protocol_config.max_recv_data_online())
|
||||
.max_recv(self.protocol_config.max_recv_data());
|
||||
|
||||
if let Some(max_sent_records) = self.protocol_config.max_sent_records() {
|
||||
builder.max_sent_records(max_sent_records);
|
||||
}
|
||||
|
||||
if let Some(max_recv_records_online) = self.protocol_config.max_recv_records_online() {
|
||||
builder.max_recv_records_online(max_recv_records_online);
|
||||
}
|
||||
|
||||
if let NetworkSetting::Latency = self.protocol_config.network() {
|
||||
builder.low_bandwidth();
|
||||
}
|
||||
|
||||
builder.build().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the prover's TLS connection.
|
||||
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TlsConfig {
|
||||
/// Root certificates.
|
||||
root_store: Option<RootCertStore>,
|
||||
/// Certificate chain and a matching private key for client
|
||||
/// authentication.
|
||||
client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
|
||||
}
|
||||
|
||||
impl TlsConfig {
|
||||
/// Creates a new builder for `TlsConfig`.
|
||||
pub fn builder() -> TlsConfigBuilder {
|
||||
TlsConfigBuilder::default()
|
||||
}
|
||||
|
||||
pub(crate) fn root_store(&self) -> Option<&RootCertStore> {
|
||||
self.root_store.as_ref()
|
||||
}
|
||||
|
||||
/// Returns a certificate chain and a matching private key for client
|
||||
/// authentication.
|
||||
pub fn client_auth(&self) -> &Option<(Vec<CertificateDer>, PrivateKeyDer)> {
|
||||
&self.client_auth
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for [`TlsConfig`].
|
||||
#[derive(Debug, Default)]
|
||||
pub struct TlsConfigBuilder {
|
||||
root_store: Option<RootCertStore>,
|
||||
client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
|
||||
}
|
||||
|
||||
impl TlsConfigBuilder {
|
||||
/// Sets the root certificates to use for verifying the server's
|
||||
/// certificate.
|
||||
pub fn root_store(&mut self, store: RootCertStore) -> &mut Self {
|
||||
self.root_store = Some(store);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets a DER-encoded certificate chain and a matching private key for
|
||||
/// client authentication.
|
||||
///
|
||||
/// Often the chain will consist of a single end-entity certificate.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cert_key` - A tuple containing the certificate chain and the private
|
||||
/// key.
|
||||
///
|
||||
/// - Each certificate in the chain must be in the X.509 format.
|
||||
/// - The key must be in the ASN.1 format (either PKCS#8 or PKCS#1).
|
||||
pub fn client_auth(&mut self, cert_key: (Vec<CertificateDer>, PrivateKeyDer)) -> &mut Self {
|
||||
self.client_auth = Some(cert_key);
|
||||
self
|
||||
}
|
||||
|
||||
/// Builds the TLS configuration.
|
||||
pub fn build(self) -> Result<TlsConfig, TlsConfigError> {
|
||||
Ok(TlsConfig {
|
||||
root_store: self.root_store,
|
||||
client_auth: self.client_auth,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// TLS configuration error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct TlsConfigError(#[from] ErrorRepr);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("tls config error")]
|
||||
enum ErrorRepr {}
|
||||
@@ -1,66 +0,0 @@
|
||||
use futures::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use futures_plex::DuplexStream;
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
/// A TLS connection to a server.
|
||||
///
|
||||
/// This type implements [`AsyncRead`] and [`AsyncWrite`] and can be used to
|
||||
/// communicate with a server using TLS.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// This connection is closed on a best-effort basis if this is dropped. To
|
||||
/// ensure a clean close, you should call
|
||||
/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the
|
||||
/// connection.
|
||||
pub struct TlsConnection {
|
||||
duplex: DuplexStream,
|
||||
}
|
||||
|
||||
impl TlsConnection {
|
||||
pub(crate) fn new(duplex: DuplexStream) -> Self {
|
||||
Self { duplex }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TlsConnection {
|
||||
fn drop(&mut self) {
|
||||
if let Err(err) = futures::executor::block_on(self.duplex.close()) {
|
||||
tracing::error!("error closing connection: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for TlsConnection {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
let duplex = Pin::new(&mut self.duplex);
|
||||
duplex.poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TlsConnection {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
let duplex = Pin::new(&mut self.duplex);
|
||||
duplex.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
let duplex = Pin::new(&mut self.duplex);
|
||||
duplex.poll_close(cx)
|
||||
}
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
use crate::prover::client::ClientHandle;
|
||||
|
||||
/// A controller for the prover.
|
||||
///
|
||||
/// Can be used to control the decryption of server traffic.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ProverControl {
|
||||
pub(crate) handle: ClientHandle,
|
||||
}
|
||||
|
||||
impl ProverControl {
|
||||
/// Returns whether the prover is decrypting the server traffic.
|
||||
pub fn is_decrypting(&self) -> bool {
|
||||
self.handle.is_decrypting()
|
||||
}
|
||||
|
||||
/// Enables or disables the decryption of server traffic.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `enable` - If decryption should be enabled or disabled.
|
||||
pub fn enable_decryption(&self, enable: bool) -> Result<(), ControlError> {
|
||||
self.handle.enable_decryption(enable)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Unable to send control command to prover.")]
|
||||
pub struct ControlError;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user