Compare commits

..

1 Commits

Author SHA1 Message Date
dan
54b0efcd10 feat(hmac-sha256): tls 1.3 key schedule 2025-09-29 13:57:46 +03:00
159 changed files with 9957 additions and 7351 deletions

View File

@@ -21,8 +21,7 @@ env:
# - https://github.com/privacy-ethereum/mpz/issues/178
# 32 seems to be big enough for the foreseeable future
RAYON_NUM_THREADS: 32
RUST_VERSION: 1.92.0
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
RUST_VERSION: 1.90.0
jobs:
clippy:
@@ -42,7 +41,7 @@ jobs:
uses: Swatinem/rust-cache@v2.7.7
- name: Clippy
run: cargo clippy --keep-going --all-features --all-targets --locked
run: cargo clippy --keep-going --all-features --all-targets --locked -- -D warnings
fmt:
name: Check formatting

View File

@@ -6,7 +6,7 @@ on:
tag:
description: 'Tag to publish to NPM'
required: true
default: 'v0.1.0-alpha.14-pre'
default: 'v0.1.0-alpha.13-pre'
jobs:
release:

View File

@@ -1,7 +1,6 @@
name: Fast-forward main branch to published release tag
on:
workflow_dispatch:
release:
types: [published]

1655
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -53,7 +53,6 @@ tlsn-formats = { path = "crates/formats" }
tlsn-hmac-sha256 = { path = "crates/components/hmac-sha256" }
tlsn-key-exchange = { path = "crates/components/key-exchange" }
tlsn-mpc-tls = { path = "crates/mpc-tls" }
tlsn-mux = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "d9facb6" }
tlsn-server-fixture = { path = "crates/server-fixture/server" }
tlsn-server-fixture-certs = { path = "crates/server-fixture/certs" }
tlsn-tls-backend = { path = "crates/tls/backend" }
@@ -67,26 +66,25 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
tlsn-wasm = { path = "crates/wasm" }
tlsn = { path = "crates/tlsn" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-circuits-data = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "d9baf0f" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" }
rangeset = { version = "0.4" }
rangeset = { version = "0.2" }
serio = { version = "0.2" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
uid-mux = { version = "0.2" }
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
aead = { version = "0.4" }
aes = { version = "0.8" }

View File

@@ -1,6 +1,6 @@
[package]
name = "tlsn-attestation"
version = "0.1.0-alpha.14-pre"
version = "0.1.0-alpha.13-pre"
edition = "2024"
[features]
@@ -9,7 +9,7 @@ fixtures = ["tlsn-core/fixtures", "dep:tlsn-data-fixtures"]
[dependencies]
tlsn-tls-core = { workspace = true }
tlsn-core = { workspace = true, features = ["mozilla-certs"] }
tlsn-core = { workspace = true }
tlsn-data-fixtures = { workspace = true, optional = true }
bcs = { workspace = true }
@@ -27,7 +27,6 @@ alloy-primitives = { version = "1.3.1", default-features = false }
alloy-signer = { version = "1.0", default-features = false }
alloy-signer-local = { version = "1.0", default-features = false }
rand06-compat = { workspace = true }
rangeset = { workspace = true }
rstest = { workspace = true }
tlsn-core = { workspace = true, features = ["fixtures"] }
tlsn-data-fixtures = { workspace = true }

View File

@@ -243,7 +243,8 @@ mod test {
use rstest::{fixture, rstest};
use tlsn_core::{
connection::{CertBinding, CertBindingV1_2},
fixtures::ConnectionFixture,
fixtures::{ConnectionFixture, encoding_provider},
hash::Blake3,
transcript::Transcript,
};
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
@@ -274,7 +275,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } = request_fixture(transcript, connection, Vec::new());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection,
Blake3::default(),
Vec::new(),
);
let attestation_config = AttestationConfig::builder()
.supported_signature_algs([SignatureAlgId::SECP256R1])
@@ -293,7 +300,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } = request_fixture(transcript, connection, Vec::new());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection,
Blake3::default(),
Vec::new(),
);
let attestation_config = AttestationConfig::builder()
.supported_signature_algs([SignatureAlgId::SECP256K1])
@@ -313,7 +326,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } = request_fixture(transcript, connection, Vec::new());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection,
Blake3::default(),
Vec::new(),
);
let attestation_builder = Attestation::builder(attestation_config)
.accept_request(request)
@@ -334,8 +353,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } =
request_fixture(transcript, connection.clone(), Vec::new());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let mut attestation_builder = Attestation::builder(attestation_config)
.accept_request(request)
@@ -359,8 +383,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } =
request_fixture(transcript, connection.clone(), Vec::new());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let mut attestation_builder = Attestation::builder(attestation_config)
.accept_request(request)
@@ -393,7 +422,9 @@ mod test {
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
vec![Extension {
id: b"foo".to_vec(),
value: b"bar".to_vec(),
@@ -420,7 +451,9 @@ mod test {
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
vec![Extension {
id: b"foo".to_vec(),
value: b"bar".to_vec(),

View File

@@ -1,29 +1,34 @@
//! Attestation fixtures.
use tlsn_core::{
connection::{CertBinding, CertBindingV1_2},
fixtures::ConnectionFixture,
transcript::{Transcript, TranscriptCommitConfigBuilder, TranscriptCommitment},
hash::HashAlgorithm,
transcript::{
Transcript, TranscriptCommitConfigBuilder, TranscriptCommitment,
encoding::{EncodingProvider, EncodingTree},
},
};
use crate::{
Attestation, AttestationConfig, CryptoProvider, Extension,
request::{Request, RequestConfig},
signing::{
KeyAlgId, SignatureAlgId, SignatureVerifier, SignatureVerifierProvider, Signer,
SignerProvider,
},
signing::SignatureAlgId,
};
/// A Request fixture used for testing.
#[allow(missing_docs)]
pub struct RequestFixture {
pub encoding_tree: EncodingTree,
pub request: Request,
}
/// Returns a request fixture for testing.
pub fn request_fixture(
transcript: Transcript,
encodings_provider: impl EncodingProvider,
connection: ConnectionFixture,
encoding_hasher: impl HashAlgorithm,
extensions: Vec<Extension>,
) -> RequestFixture {
let provider = CryptoProvider::default();
@@ -43,9 +48,15 @@ pub fn request_fixture(
.unwrap();
let transcripts_commitment_config = transcript_commitment_builder.build().unwrap();
let mut builder = RequestConfig::builder();
// Prover constructs encoding tree.
let encoding_tree = EncodingTree::new(
&encoding_hasher,
transcripts_commitment_config.iter_encoding(),
&encodings_provider,
)
.unwrap();
builder.transcript_commit(transcripts_commitment_config);
let mut builder = RequestConfig::builder();
for extension in extensions {
builder.extension(extension);
@@ -61,7 +72,10 @@ pub fn request_fixture(
let (request, _) = request_builder.build(&provider).unwrap();
RequestFixture { request }
RequestFixture {
encoding_tree,
request,
}
}
/// Returns an attestation fixture for testing.
@@ -88,8 +102,7 @@ pub fn attestation_fixture(
let mut provider = CryptoProvider::default();
match signature_alg {
SignatureAlgId::SECP256K1 => provider.signer.set_secp256k1(&[42u8; 32]).unwrap(),
SignatureAlgId::SECP256K1ETH => provider.signer.set_secp256k1eth(&[43u8; 32]).unwrap(),
SignatureAlgId::SECP256R1 => provider.signer.set_secp256r1(&[44u8; 32]).unwrap(),
SignatureAlgId::SECP256R1 => provider.signer.set_secp256r1(&[42u8; 32]).unwrap(),
_ => unimplemented!(),
};
@@ -109,68 +122,3 @@ pub fn attestation_fixture(
attestation_builder.build(&provider).unwrap()
}
/// Returns a crypto provider which supports only a custom signature alg.
pub fn custom_provider_fixture() -> CryptoProvider {
const CUSTOM_SIG_ALG_ID: SignatureAlgId = SignatureAlgId::new(128);
// A dummy signer.
struct DummySigner {}
impl Signer for DummySigner {
fn alg_id(&self) -> SignatureAlgId {
CUSTOM_SIG_ALG_ID
}
fn sign(
&self,
msg: &[u8],
) -> Result<crate::signing::Signature, crate::signing::SignatureError> {
Ok(crate::signing::Signature {
alg: CUSTOM_SIG_ALG_ID,
data: msg.to_vec(),
})
}
fn verifying_key(&self) -> crate::signing::VerifyingKey {
crate::signing::VerifyingKey {
alg: KeyAlgId::new(128),
data: vec![1, 2, 3, 4],
}
}
}
// A dummy verifier.
struct DummyVerifier {}
impl SignatureVerifier for DummyVerifier {
fn alg_id(&self) -> SignatureAlgId {
CUSTOM_SIG_ALG_ID
}
fn verify(
&self,
_key: &crate::signing::VerifyingKey,
msg: &[u8],
sig: &[u8],
) -> Result<(), crate::signing::SignatureError> {
if msg == sig {
Ok(())
} else {
Err(crate::signing::SignatureError::from_str(
"invalid signature",
))
}
}
}
let mut provider = CryptoProvider::default();
let mut signer_provider = SignerProvider::default();
signer_provider.set_signer(Box::new(DummySigner {}));
provider.signer = signer_provider;
let mut verifier_provider = SignatureVerifierProvider::empty();
verifier_provider.set_verifier(Box::new(DummyVerifier {}));
provider.signature = verifier_provider;
provider
}

View File

@@ -79,6 +79,8 @@
//!
//! // Specify all the transcript commitments we want to make.
//! builder
//! // Use BLAKE3 for encoding commitments.
//! .encoding_hash_alg(HashAlgId::BLAKE3)
//! // Commit to all sent data.
//! .commit_sent(&(0..sent_len))?
//! // Commit to the first 10 bytes of sent data.
@@ -127,7 +129,7 @@
//!
//! ```no_run
//! # use tlsn_attestation::{Attestation, CryptoProvider, Secrets, presentation::Presentation};
//! # use tlsn_core::transcript::Direction;
//! # use tlsn_core::transcript::{TranscriptCommitmentKind, Direction};
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! # let attestation: Attestation = unimplemented!();
//! # let secrets: Secrets = unimplemented!();
@@ -138,6 +140,8 @@
//! let mut builder = secrets.transcript_proof_builder();
//!
//! builder
//! // Use transcript encoding commitments.
//! .commitment_kinds(&[TranscriptCommitmentKind::Encoding])
//! // Disclose the first 10 bytes of the sent data.
//! .reveal(&(0..10), Direction::Sent)?
//! // Disclose all of the received data.
@@ -297,6 +301,8 @@ pub enum FieldKind {
ServerEphemKey = 0x02,
/// Server identity commitment.
ServerIdentityCommitment = 0x03,
/// Encoding commitment.
EncodingCommitment = 0x04,
/// Plaintext hash commitment.
PlaintextHash = 0x05,
}

View File

@@ -20,10 +20,7 @@ use serde::{Deserialize, Serialize};
use tlsn_core::hash::HashAlgId;
use crate::{
Attestation, CryptoProvider, Extension, connection::ServerCertCommitment,
serialize::CanonicalSerialize, signing::SignatureAlgId,
};
use crate::{Attestation, Extension, connection::ServerCertCommitment, signing::SignatureAlgId};
pub use builder::{RequestBuilder, RequestBuilderError};
pub use config::{RequestConfig, RequestConfigBuilder, RequestConfigBuilderError};
@@ -44,107 +41,51 @@ impl Request {
}
/// Validates the content of the attestation against this request.
pub fn validate(
&self,
attestation: &Attestation,
provider: &CryptoProvider,
) -> Result<(), AttestationValidationError> {
pub fn validate(&self, attestation: &Attestation) -> Result<(), InconsistentAttestation> {
if attestation.signature.alg != self.signature_alg {
return Err(AttestationValidationError::inconsistent(format!(
return Err(InconsistentAttestation(format!(
"signature algorithm: expected {:?}, got {:?}",
self.signature_alg, attestation.signature.alg
)));
}
if attestation.header.root.alg != self.hash_alg {
return Err(AttestationValidationError::inconsistent(format!(
return Err(InconsistentAttestation(format!(
"hash algorithm: expected {:?}, got {:?}",
self.hash_alg, attestation.header.root.alg
)));
}
if attestation.body.cert_commitment() != &self.server_cert_commitment {
return Err(AttestationValidationError::inconsistent(
"server certificate commitment does not match",
return Err(InconsistentAttestation(
"server certificate commitment does not match".to_string(),
));
}
// TODO: improve the O(M*N) complexity of this check.
for extension in &self.extensions {
if !attestation.body.extensions().any(|e| e == extension) {
return Err(AttestationValidationError::inconsistent(
"extension is missing from the attestation",
return Err(InconsistentAttestation(
"extension is missing from the attestation".to_string(),
));
}
}
let verifier = provider
.signature
.get(&attestation.signature.alg)
.map_err(|_| {
AttestationValidationError::provider(format!(
"provider not configured for signature algorithm id {:?}",
attestation.signature.alg,
))
})?;
verifier
.verify(
&attestation.body.verifying_key.data,
&CanonicalSerialize::serialize(&attestation.header),
&attestation.signature.data,
)
.map_err(|_| {
AttestationValidationError::inconsistent("failed to verify the signature")
})?;
Ok(())
}
}
/// Error for [`Request::validate`].
#[derive(Debug, thiserror::Error)]
#[error("attestation validation error: {kind}: {message}")]
pub struct AttestationValidationError {
kind: ErrorKind,
message: String,
}
impl AttestationValidationError {
fn inconsistent(msg: impl Into<String>) -> Self {
Self {
kind: ErrorKind::Inconsistent,
message: msg.into(),
}
}
fn provider(msg: impl Into<String>) -> Self {
Self {
kind: ErrorKind::Provider,
message: msg.into(),
}
}
}
#[derive(Debug)]
enum ErrorKind {
Inconsistent,
Provider,
}
impl std::fmt::Display for ErrorKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ErrorKind::Inconsistent => write!(f, "inconsistent"),
ErrorKind::Provider => write!(f, "provider"),
}
}
}
#[error("inconsistent attestation: {0}")]
pub struct InconsistentAttestation(String);
#[cfg(test)]
mod test {
use tlsn_core::{
connection::TranscriptLength, fixtures::ConnectionFixture, hash::HashAlgId,
connection::TranscriptLength,
fixtures::{ConnectionFixture, encoding_provider},
hash::{Blake3, HashAlgId},
transcript::Transcript,
};
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
@@ -152,8 +93,7 @@ mod test {
use crate::{
CryptoProvider,
connection::ServerCertOpening,
fixtures::{RequestFixture, attestation_fixture, custom_provider_fixture, request_fixture},
request::{AttestationValidationError, ErrorKind},
fixtures::{RequestFixture, attestation_fixture, request_fixture},
signing::SignatureAlgId,
};
@@ -162,15 +102,18 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } =
request_fixture(transcript, connection.clone(), Vec::new());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
let provider = CryptoProvider::default();
assert!(request.validate(&attestation, &provider).is_ok())
assert!(request.validate(&attestation).is_ok())
}
#[test]
@@ -178,17 +121,20 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { mut request, .. } =
request_fixture(transcript, connection.clone(), Vec::new());
let RequestFixture { mut request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
request.signature_alg = SignatureAlgId::SECP256R1;
let provider = CryptoProvider::default();
let res = request.validate(&attestation, &provider);
let res = request.validate(&attestation);
assert!(res.is_err());
}
@@ -197,17 +143,20 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { mut request, .. } =
request_fixture(transcript, connection.clone(), Vec::new());
let RequestFixture { mut request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
request.hash_alg = HashAlgId::SHA256;
let provider = CryptoProvider::default();
let res = request.validate(&attestation, &provider);
let res = request.validate(&attestation);
assert!(res.is_err())
}
@@ -216,8 +165,13 @@ mod test {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { mut request, .. } =
request_fixture(transcript, connection.clone(), Vec::new());
let RequestFixture { mut request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
@@ -230,52 +184,11 @@ mod test {
});
let opening = ServerCertOpening::new(server_cert_data);
let provider = CryptoProvider::default();
let crypto_provider = CryptoProvider::default();
request.server_cert_commitment =
opening.commit(provider.hash.get(&HashAlgId::BLAKE3).unwrap());
opening.commit(crypto_provider.hash.get(&HashAlgId::BLAKE3).unwrap());
let res = request.validate(&attestation, &provider);
let res = request.validate(&attestation);
assert!(res.is_err())
}
#[test]
fn test_wrong_sig() {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } =
request_fixture(transcript, connection.clone(), Vec::new());
let mut attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
// Corrupt the signature.
attestation.signature.data[1] = attestation.signature.data[1].wrapping_add(1);
let provider = CryptoProvider::default();
assert!(request.validate(&attestation, &provider).is_err())
}
#[test]
fn test_wrong_provider() {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } =
request_fixture(transcript, connection.clone(), Vec::new());
let attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
let provider = custom_provider_fixture();
assert!(matches!(
request.validate(&attestation, &provider),
Err(AttestationValidationError {
kind: ErrorKind::Provider,
..
})
))
}
}

View File

@@ -49,4 +49,5 @@ impl_domain_separator!(tlsn_core::connection::ConnectionInfo);
impl_domain_separator!(tlsn_core::connection::CertBinding);
impl_domain_separator!(tlsn_core::transcript::TranscriptCommitment);
impl_domain_separator!(tlsn_core::transcript::TranscriptSecret);
impl_domain_separator!(tlsn_core::transcript::encoding::EncodingCommitment);
impl_domain_separator!(tlsn_core::transcript::hash::PlaintextHash);

View File

@@ -202,14 +202,6 @@ impl SignatureVerifierProvider {
.map(|s| &**s)
.ok_or(UnknownSignatureAlgId(*alg))
}
/// Returns am empty provider.
#[cfg(any(test, feature = "fixtures"))]
pub fn empty() -> Self {
Self {
verifiers: HashMap::default(),
}
}
}
/// Signature verifier.
@@ -237,14 +229,6 @@ impl_domain_separator!(VerifyingKey);
#[error("signature verification failed: {0}")]
pub struct SignatureError(String);
impl SignatureError {
/// Creates a new error with the given message.
#[allow(clippy::should_implement_trait)]
pub fn from_str(msg: &str) -> Self {
Self(msg.to_string())
}
}
/// A signature.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Signature {

View File

@@ -1,5 +1,3 @@
use rand::{Rng, SeedableRng, rngs::StdRng};
use rangeset::set::RangeSet;
use tlsn_attestation::{
Attestation, AttestationConfig, CryptoProvider,
presentation::PresentationOutput,
@@ -8,11 +6,12 @@ use tlsn_attestation::{
};
use tlsn_core::{
connection::{CertBinding, CertBindingV1_2},
fixtures::ConnectionFixture,
hash::{Blake3, Blinder, HashAlgId},
fixtures::{self, ConnectionFixture, encoder_secret},
hash::Blake3,
transcript::{
Direction, Transcript, TranscriptCommitment, TranscriptSecret,
hash::{PlaintextHash, PlaintextHashSecret, hash_plaintext},
Direction, Transcript, TranscriptCommitConfigBuilder, TranscriptCommitment,
TranscriptSecret,
encoding::{EncodingCommitment, EncodingTree},
},
};
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
@@ -20,7 +19,6 @@ use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
/// Tests that the attestation protocol and verification work end-to-end
#[test]
fn test_api() {
let mut rng = StdRng::seed_from_u64(0);
let mut provider = CryptoProvider::default();
// Configure signer for Notary
@@ -28,6 +26,8 @@ fn test_api() {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let (sent_len, recv_len) = transcript.len();
// Plaintext encodings which the Prover obtained from GC evaluation
let encodings_provider = fixtures::encoding_provider(GET_WITH_HEADER, OK_JSON);
// At the end of the TLS connection the Prover holds the:
let ConnectionFixture {
@@ -44,38 +44,27 @@ fn test_api() {
unreachable!()
};
// Create hash commitments
let hasher = Blake3::default();
let sent_blinder: Blinder = rng.random();
let recv_blinder: Blinder = rng.random();
// Prover specifies the ranges it wants to commit to.
let mut transcript_commitment_builder = TranscriptCommitConfigBuilder::new(&transcript);
transcript_commitment_builder
.commit_sent(&(0..sent_len))
.unwrap()
.commit_recv(&(0..recv_len))
.unwrap();
let sent_idx = RangeSet::from(0..sent_len);
let recv_idx = RangeSet::from(0..recv_len);
let transcripts_commitment_config = transcript_commitment_builder.build().unwrap();
let sent_hash_commitment = PlaintextHash {
direction: Direction::Sent,
idx: sent_idx.clone(),
hash: hash_plaintext(&hasher, transcript.sent(), &sent_blinder),
};
// Prover constructs encoding tree.
let encoding_tree = EncodingTree::new(
&Blake3::default(),
transcripts_commitment_config.iter_encoding(),
&encodings_provider,
)
.unwrap();
let recv_hash_commitment = PlaintextHash {
direction: Direction::Received,
idx: recv_idx.clone(),
hash: hash_plaintext(&hasher, transcript.received(), &recv_blinder),
};
let sent_hash_secret = PlaintextHashSecret {
direction: Direction::Sent,
idx: sent_idx,
alg: HashAlgId::BLAKE3,
blinder: sent_blinder,
};
let recv_hash_secret = PlaintextHashSecret {
direction: Direction::Received,
idx: recv_idx,
alg: HashAlgId::BLAKE3,
blinder: recv_blinder,
let encoding_commitment = EncodingCommitment {
root: encoding_tree.root(),
secret: encoder_secret(),
};
let request_config = RequestConfig::default();
@@ -86,14 +75,8 @@ fn test_api() {
.handshake_data(server_cert_data)
.transcript(transcript)
.transcript_commitments(
vec![
TranscriptSecret::Hash(sent_hash_secret),
TranscriptSecret::Hash(recv_hash_secret),
],
vec![
TranscriptCommitment::Hash(sent_hash_commitment.clone()),
TranscriptCommitment::Hash(recv_hash_commitment.clone()),
],
vec![TranscriptSecret::Encoding(encoding_tree)],
vec![TranscriptCommitment::Encoding(encoding_commitment.clone())],
);
let (request, secrets) = request_builder.build(&provider).unwrap();
@@ -113,15 +96,12 @@ fn test_api() {
.connection_info(connection_info.clone())
// Server key Notary received during handshake
.server_ephemeral_key(server_ephemeral_key)
.transcript_commitments(vec![
TranscriptCommitment::Hash(sent_hash_commitment),
TranscriptCommitment::Hash(recv_hash_commitment),
]);
.transcript_commitments(vec![TranscriptCommitment::Encoding(encoding_commitment)]);
let attestation = attestation_builder.build(&provider).unwrap();
// Prover validates the attestation is consistent with its request.
request.validate(&attestation, &provider).unwrap();
request.validate(&attestation).unwrap();
let mut transcript_proof_builder = secrets.transcript_proof_builder();

View File

@@ -5,7 +5,7 @@ description = "This crate provides implementations of ciphers for two parties"
keywords = ["tls", "mpc", "2pc", "aes"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.14-pre"
version = "0.1.0-alpha.13-pre"
edition = "2021"
[lints]
@@ -15,7 +15,7 @@ workspace = true
name = "cipher"
[dependencies]
mpz-circuits = { workspace = true, features = ["aes"] }
mpz-circuits = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-memory-core = { workspace = true }
@@ -24,9 +24,11 @@ thiserror = { workspace = true }
aes = { workspace = true }
[dev-dependencies]
mpz-common = { workspace = true, features = ["test-utils"] }
mpz-ideal-vm = { workspace = true }
mpz-garble = { workspace = true }
mpz-common = { workspace = true }
mpz-ot = { workspace = true }
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }
ctr = { workspace = true }
cipher = { workspace = true }

View File

@@ -2,7 +2,7 @@
use crate::{Cipher, CtrBlock, Keystream};
use async_trait::async_trait;
use mpz_circuits::{AES128_KS, AES128_POST_KS};
use mpz_circuits::circuits::AES128;
use mpz_memory_core::binary::{Binary, U8};
use mpz_vm_core::{prelude::*, Call, Vm};
use std::fmt::Debug;
@@ -12,35 +12,13 @@ mod error;
pub use error::AesError;
use error::ErrorKind;
/// AES key schedule: 11 round keys, 16 bytes each.
type KeySchedule = Array<U8, 176>;
/// Computes AES-128.
#[derive(Default, Debug)]
pub struct Aes128 {
key: Option<Array<U8, 16>>,
key_schedule: Option<KeySchedule>,
iv: Option<Array<U8, 4>>,
}
impl Aes128 {
// Allocates key schedule.
//
// Expects the key to be already set.
fn alloc_key_schedule(&self, vm: &mut dyn Vm<Binary>) -> Result<KeySchedule, AesError> {
let ks: KeySchedule = vm
.call(
Call::builder(AES128_KS.clone())
.arg(self.key.expect("key is set"))
.build()
.expect("call should be valid"),
)
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
Ok(ks)
}
}
#[async_trait]
impl Cipher for Aes128 {
type Error = AesError;
@@ -67,22 +45,18 @@ impl Cipher for Aes128 {
}
fn alloc_block(
&mut self,
&self,
vm: &mut dyn Vm<Binary>,
input: Array<U8, 16>,
) -> Result<Self::Block, Self::Error> {
self.key
let key = self
.key
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
if self.key_schedule.is_none() {
self.key_schedule = Some(self.alloc_key_schedule(vm)?);
}
let ks = *self.key_schedule.as_ref().expect("key schedule was set");
let output = vm
.call(
Call::builder(AES128_POST_KS.clone())
.arg(ks)
Call::builder(AES128.clone())
.arg(key)
.arg(input)
.build()
.expect("call should be valid"),
@@ -93,10 +67,11 @@ impl Cipher for Aes128 {
}
fn alloc_ctr_block(
&mut self,
&self,
vm: &mut dyn Vm<Binary>,
) -> Result<CtrBlock<Self::Nonce, Self::Counter, Self::Block>, Self::Error> {
self.key
let key = self
.key
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
let iv = self
.iv
@@ -114,15 +89,10 @@ impl Cipher for Aes128 {
vm.mark_public(counter)
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
if self.key_schedule.is_none() {
self.key_schedule = Some(self.alloc_key_schedule(vm)?);
}
let ks = *self.key_schedule.as_ref().expect("key schedule was set");
let output = vm
.call(
Call::builder(AES128_POST_KS.clone())
.arg(ks)
Call::builder(AES128.clone())
.arg(key)
.arg(iv)
.arg(explicit_nonce)
.arg(counter)
@@ -139,11 +109,12 @@ impl Cipher for Aes128 {
}
fn alloc_keystream(
&mut self,
&self,
vm: &mut dyn Vm<Binary>,
len: usize,
) -> Result<Keystream<Self::Nonce, Self::Counter, Self::Block>, Self::Error> {
self.key
let key = self
.key
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
let iv = self
.iv
@@ -172,15 +143,10 @@ impl Cipher for Aes128 {
let blocks = inputs
.into_iter()
.map(|(explicit_nonce, counter)| {
if self.key_schedule.is_none() {
self.key_schedule = Some(self.alloc_key_schedule(vm)?);
}
let ks = *self.key_schedule.as_ref().expect("key schedule was set");
let output = vm
.call(
Call::builder(AES128_POST_KS.clone())
.arg(ks)
Call::builder(AES128.clone())
.arg(key)
.arg(iv)
.arg(explicit_nonce)
.arg(counter)
@@ -206,12 +172,15 @@ mod tests {
use super::*;
use crate::Cipher;
use mpz_common::context::test_st_context;
use mpz_ideal_vm::IdealVm;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_memory_core::{
binary::{Binary, U8},
correlated::Delta,
Array, MemoryExt, Vector, ViewExt,
};
use mpz_ot::ideal::cot::ideal_cot;
use mpz_vm_core::{Execute, Vm};
use rand::{rngs::StdRng, SeedableRng};
#[tokio::test]
async fn test_aes_ctr() {
@@ -221,11 +190,10 @@ mod tests {
let start_counter = 3u32;
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let mut gen = IdealVm::new();
let mut ev = IdealVm::new();
let (mut gen, mut ev) = mock_vm();
let mut aes_gen = setup_ctr(key, iv, &mut gen);
let mut aes_ev = setup_ctr(key, iv, &mut ev);
let aes_gen = setup_ctr(key, iv, &mut gen);
let aes_ev = setup_ctr(key, iv, &mut ev);
let msg = vec![42u8; 128];
@@ -284,11 +252,10 @@ mod tests {
let input = [5_u8; 16];
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let mut gen = IdealVm::new();
let mut ev = IdealVm::new();
let (mut gen, mut ev) = mock_vm();
let mut aes_gen = setup_block(key, &mut gen);
let mut aes_ev = setup_block(key, &mut ev);
let aes_gen = setup_block(key, &mut gen);
let aes_ev = setup_block(key, &mut ev);
let block_ref_gen: Array<U8, 16> = gen.alloc().unwrap();
gen.mark_public(block_ref_gen).unwrap();
@@ -327,6 +294,18 @@ mod tests {
assert_eq!(ciphertext_gen, expected);
}
/// Builds a semi-honest garbler/evaluator VM pair backed by an ideal COT,
/// seeded deterministically so tests are reproducible.
fn mock_vm() -> (impl Vm<Binary>, impl Vm<Binary>) {
    let delta = Delta::random(&mut StdRng::seed_from_u64(0));
    let (cot_sender, cot_receiver) = ideal_cot(delta.into_inner());
    (
        Garbler::new(cot_sender, [0u8; 16], delta),
        Evaluator::new(cot_receiver),
    )
}
fn setup_ctr(key: [u8; 16], iv: [u8; 4], vm: &mut dyn Vm<Binary>) -> Aes128 {
let key_ref: Array<U8, 16> = vm.alloc().unwrap();
vm.mark_public(key_ref).unwrap();

View File

@@ -55,7 +55,7 @@ pub trait Cipher {
/// Allocates a single block in ECB mode.
fn alloc_block(
&mut self,
&self,
vm: &mut dyn Vm<Binary>,
input: Self::Block,
) -> Result<Self::Block, Self::Error>;
@@ -63,7 +63,7 @@ pub trait Cipher {
/// Allocates a single block in counter mode.
#[allow(clippy::type_complexity)]
fn alloc_ctr_block(
&mut self,
&self,
vm: &mut dyn Vm<Binary>,
) -> Result<CtrBlock<Self::Nonce, Self::Counter, Self::Block>, Self::Error>;
@@ -75,7 +75,7 @@ pub trait Cipher {
/// * `len` - Length of the stream in bytes.
#[allow(clippy::type_complexity)]
fn alloc_keystream(
&mut self,
&self,
vm: &mut dyn Vm<Binary>,
len: usize,
) -> Result<Keystream<Self::Nonce, Self::Counter, Self::Block>, Self::Error>;

View File

@@ -1,6 +1,6 @@
[package]
name = "tlsn-deap"
version = "0.1.0-alpha.14-pre"
version = "0.1.0-alpha.13-pre"
edition = "2021"
[lints]
@@ -19,8 +19,11 @@ futures = { workspace = true }
tokio = { workspace = true, features = ["sync"] }
[dev-dependencies]
mpz-circuits = { workspace = true, features = ["aes"] }
mpz-common = { workspace = true, features = ["test-utils"] }
mpz-ideal-vm = { workspace = true }
mpz-circuits = { workspace = true }
mpz-garble = { workspace = true }
mpz-ot = { workspace = true }
mpz-zk = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }
rand06-compat = { workspace = true }

View File

@@ -15,7 +15,7 @@ use mpz_vm_core::{
memory::{binary::Binary, DecodeFuture, Memory, Repr, Slice, View},
Call, Callable, Execute, Vm, VmError,
};
use rangeset::{ops::Set, set::RangeSet};
use rangeset::{Difference, RangeSet, UnionMut};
use tokio::sync::{Mutex, MutexGuard, OwnedMutexGuard};
type Error = DeapError;
@@ -210,12 +210,10 @@ where
}
fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
let slice_range = slice.to_range();
// Follower's private inputs are not committed in the ZK VM until finalization.
let input_minus_follower = slice_range.difference(&self.follower_input_ranges);
let input_minus_follower = slice.to_range().difference(&self.follower_input_ranges);
let mut zk = self.zk.try_lock().unwrap();
for input in input_minus_follower {
for input in input_minus_follower.iter_ranges() {
zk.commit_raw(
self.memory_map
.try_get(Slice::from_range_unchecked(input))?,
@@ -268,7 +266,7 @@ where
mpc.mark_private_raw(slice)?;
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
self.follower_input_ranges.union_mut(slice.to_range());
self.follower_input_ranges.union_mut(&slice.to_range());
self.follower_inputs.push(slice);
}
}
@@ -284,7 +282,7 @@ where
mpc.mark_blind_raw(slice)?;
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
self.follower_input_ranges.union_mut(slice.to_range());
self.follower_input_ranges.union_mut(&slice.to_range());
self.follower_inputs.push(slice);
}
Role::Follower => {
@@ -384,27 +382,37 @@ enum ErrorRepr {
#[cfg(test)]
mod tests {
use mpz_circuits::AES128;
use mpz_circuits::circuits::AES128;
use mpz_common::context::test_st_context;
use mpz_ideal_vm::IdealVm;
use mpz_core::Block;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_ot::ideal::{cot::ideal_cot, rcot::ideal_rcot};
use mpz_vm_core::{
memory::{binary::U8, Array},
memory::{binary::U8, correlated::Delta, Array},
prelude::*,
};
use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
use rand::{rngs::StdRng, SeedableRng};
use super::*;
#[tokio::test]
async fn test_deap() {
let mut rng = StdRng::seed_from_u64(0);
let delta_mpc = Delta::random(&mut rng);
let delta_zk = Delta::random(&mut rng);
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
let leader_mpc = IdealVm::new();
let leader_zk = IdealVm::new();
let follower_mpc = IdealVm::new();
let follower_zk = IdealVm::new();
let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc);
let ev = Evaluator::new(cot_recv);
let prover = Prover::new(ProverConfig::default(), rcot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta_zk, rcot_send);
let mut leader = Deap::new(Role::Leader, leader_mpc, leader_zk);
let mut follower = Deap::new(Role::Follower, follower_mpc, follower_zk);
let mut leader = Deap::new(Role::Leader, gb, prover);
let mut follower = Deap::new(Role::Follower, ev, verifier);
let (ct_leader, ct_follower) = futures::join!(
async {
@@ -470,15 +478,21 @@ mod tests {
#[tokio::test]
async fn test_deap_desync_memory() {
let mut rng = StdRng::seed_from_u64(0);
let delta_mpc = Delta::random(&mut rng);
let delta_zk = Delta::random(&mut rng);
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
let leader_mpc = IdealVm::new();
let leader_zk = IdealVm::new();
let follower_mpc = IdealVm::new();
let follower_zk = IdealVm::new();
let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc);
let ev = Evaluator::new(cot_recv);
let prover = Prover::new(ProverConfig::default(), rcot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta_zk, rcot_send);
let mut leader = Deap::new(Role::Leader, leader_mpc, leader_zk);
let mut follower = Deap::new(Role::Follower, follower_mpc, follower_zk);
let mut leader = Deap::new(Role::Leader, gb, prover);
let mut follower = Deap::new(Role::Follower, ev, verifier);
// Desynchronize the memories.
let _ = leader.zk().alloc_raw(1).unwrap();
@@ -550,15 +564,21 @@ mod tests {
// detection by the follower.
#[tokio::test]
async fn test_malicious() {
let mut rng = StdRng::seed_from_u64(0);
let delta_mpc = Delta::random(&mut rng);
let delta_zk = Delta::random(&mut rng);
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
let leader_mpc = IdealVm::new();
let leader_zk = IdealVm::new();
let follower_mpc = IdealVm::new();
let follower_zk = IdealVm::new();
let gb = Garbler::new(cot_send, [1u8; 16], delta_mpc);
let ev = Evaluator::new(cot_recv);
let prover = Prover::new(ProverConfig::default(), rcot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta_zk, rcot_send);
let mut leader = Deap::new(Role::Leader, leader_mpc, leader_zk);
let mut follower = Deap::new(Role::Follower, follower_mpc, follower_zk);
let mut leader = Deap::new(Role::Leader, gb, prover);
let mut follower = Deap::new(Role::Follower, ev, verifier);
let (_, follower_res) = futures::join!(
async {

View File

@@ -1,7 +1,7 @@
use std::ops::Range;
use mpz_vm_core::{memory::Slice, VmError};
use rangeset::ops::Set;
use rangeset::Subset;
/// A mapping between the memories of the MPC and ZK VMs.
#[derive(Debug, Default)]

View File

@@ -5,7 +5,7 @@ description = "A 2PC implementation of TLS HMAC-SHA256 PRF"
keywords = ["tls", "mpc", "2pc", "hmac", "sha256"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.14-pre"
version = "0.1.0-alpha.13-pre"
edition = "2021"
[lints]
@@ -20,20 +20,28 @@ mpz-core = { workspace = true }
mpz-circuits = { workspace = true }
mpz-hash = { workspace = true }
sha2 = { workspace = true, features = ["compress"] }
rand = { workspace = true }
sha2 = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
mpz-ot = { workspace = true, features = ["ideal"] }
mpz-garble = { workspace = true }
mpz-common = { workspace = true, features = ["test-utils"] }
mpz-ideal-vm = { workspace = true }
criterion = { workspace = true, features = ["async_tokio"] }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }
hex = { workspace = true }
hmac = { workspace = true }
ring = { workspace = true }
rstest = { workspace = true }
sha2 = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
[[bench]]
name = "prf"
name = "tls12"
harness = false
[[bench]]
name = "tls13"
harness = false

View File

@@ -2,31 +2,37 @@
use criterion::{criterion_group, criterion_main, Criterion};
use hmac_sha256::{Mode, MpcPrf};
use hmac_sha256::{Mode, Tls12Prf};
use mpz_common::context::test_mt_context;
use mpz_ideal_vm::IdealVm;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_ot::ideal::cot::ideal_cot;
use mpz_vm_core::{
memory::{binary::U8, Array},
memory::{binary::U8, correlated::Delta, Array},
prelude::*,
Execute,
};
use rand::{rngs::StdRng, SeedableRng};
#[allow(clippy::unit_arg)]
fn criterion_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("prf");
let mut group = c.benchmark_group("tls12");
group.sample_size(10);
let rt = tokio::runtime::Runtime::new().unwrap();
group.bench_function("prf_normal", |b| b.to_async(&rt).iter(|| prf(Mode::Normal)));
group.bench_function("prf_reduced", |b| {
b.to_async(&rt).iter(|| prf(Mode::Reduced))
group.bench_function("tls12_normal", |b| {
b.to_async(&rt).iter(|| tls12(Mode::Normal))
});
group.bench_function("tls12_reduced", |b| {
b.to_async(&rt).iter(|| tls12(Mode::Reduced))
});
}
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
async fn prf(mode: Mode) {
async fn tls12(mode: Mode) {
let mut rng = StdRng::seed_from_u64(0);
let pms = [42u8; 32];
let client_random = [69u8; 32];
let server_random: [u8; 32] = [96u8; 32];
@@ -35,8 +41,11 @@ async fn prf(mode: Mode) {
let mut leader_ctx = leader_exec.new_context().await.unwrap();
let mut follower_ctx = follower_exec.new_context().await.unwrap();
let mut leader_vm = IdealVm::new();
let mut follower_vm = IdealVm::new();
let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_cot(delta.into_inner());
let mut leader_vm = Garbler::new(ot_send, [0u8; 16], delta);
let mut follower_vm = Evaluator::new(ot_recv);
let leader_pms: Array<U8, 32> = leader_vm.alloc().unwrap();
leader_vm.mark_public(leader_pms).unwrap();
@@ -48,8 +57,8 @@ async fn prf(mode: Mode) {
follower_vm.assign(follower_pms, pms).unwrap();
follower_vm.commit(follower_pms).unwrap();
let mut leader = MpcPrf::new(mode);
let mut follower = MpcPrf::new(mode);
let mut leader = Tls12Prf::new(mode);
let mut follower = Tls12Prf::new(mode);
let leader_output = leader.alloc(&mut leader_vm, leader_pms).unwrap();
let follower_output = follower.alloc(&mut follower_vm, follower_pms).unwrap();

View File

@@ -0,0 +1,139 @@
#![allow(clippy::let_underscore_future)]
use criterion::{criterion_group, criterion_main, Criterion};
use hmac_sha256::{Mode, Role, Tls13KeySched};
use mpz_common::context::test_mt_context;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_ot::ideal::cot::ideal_cot;
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
correlated::Delta,
Array,
},
prelude::*,
Execute, Vm,
};
use rand::{rngs::StdRng, SeedableRng};
#[allow(clippy::unit_arg)]
fn criterion_benchmark(c: &mut Criterion) {
    // A single runtime is shared by both benchmark cases.
    let rt = tokio::runtime::Runtime::new().unwrap();

    let mut group = c.benchmark_group("tls13");
    group.sample_size(10);

    group.bench_function("tls13_normal", |b| {
        b.to_async(&rt).iter(|| tls13(Mode::Normal))
    });
    group.bench_function("tls13_reduced", |b| {
        b.to_async(&rt).iter(|| tls13(Mode::Reduced))
    });
}
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
async fn tls13(mode: Mode) {
let mut rng = StdRng::seed_from_u64(0);
let pms = [42u8; 32];
let (mut leader_exec, mut follower_exec) = test_mt_context(8);
let mut leader_ctx = leader_exec.new_context().await.unwrap();
let mut follower_ctx = follower_exec.new_context().await.unwrap();
let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_cot(delta.into_inner());
let mut leader_vm = Garbler::new(ot_send, [0u8; 16], delta);
let mut follower_vm = Evaluator::new(ot_recv);
fn setup_ks(
vm: &mut (dyn Vm<Binary> + Send),
pms: [u8; 32],
mode: Mode,
role: Role,
) -> Tls13KeySched {
let secret: Array<U8, 32> = vm.alloc().unwrap();
vm.mark_public(secret).unwrap();
vm.assign(secret, pms).unwrap();
vm.commit(secret).unwrap();
let mut ks = Tls13KeySched::new(mode, role);
ks.alloc(vm, secret).unwrap();
ks
}
let mut leader_ks = setup_ks(&mut leader_vm, pms, mode, Role::Leader);
let mut follower_ks = setup_ks(&mut follower_vm, pms, mode, Role::Follower);
while leader_ks.wants_flush() || follower_ks.wants_flush() {
tokio::try_join!(
async {
leader_ks.flush(&mut leader_vm).unwrap();
leader_vm.execute_all(&mut leader_ctx).await
},
async {
follower_ks.flush(&mut follower_vm).unwrap();
follower_vm.execute_all(&mut follower_ctx).await
}
)
.unwrap();
}
let hello_hash = [1u8; 32];
leader_ks.set_hello_hash(hello_hash).unwrap();
follower_ks.set_hello_hash(hello_hash).unwrap();
while leader_ks.wants_flush() || follower_ks.wants_flush() {
tokio::try_join!(
async {
leader_ks.flush(&mut leader_vm).unwrap();
leader_vm.execute_all(&mut leader_ctx).await
},
async {
follower_ks.flush(&mut follower_vm).unwrap();
follower_vm.execute_all(&mut follower_ctx).await
}
)
.unwrap();
}
leader_ks.continue_to_app_keys().unwrap();
follower_ks.continue_to_app_keys().unwrap();
while leader_ks.wants_flush() || follower_ks.wants_flush() {
tokio::try_join!(
async {
leader_ks.flush(&mut leader_vm).unwrap();
leader_vm.execute_all(&mut leader_ctx).await
},
async {
follower_ks.flush(&mut follower_vm).unwrap();
follower_vm.execute_all(&mut follower_ctx).await
}
)
.unwrap();
}
let handshake_hash = [2u8; 32];
leader_ks.set_handshake_hash(handshake_hash).unwrap();
follower_ks.set_handshake_hash(handshake_hash).unwrap();
while leader_ks.wants_flush() || follower_ks.wants_flush() {
tokio::try_join!(
async {
leader_ks.flush(&mut leader_vm).unwrap();
leader_vm.execute_all(&mut leader_ctx).await
},
async {
follower_ks.flush(&mut follower_vm).unwrap();
follower_vm.execute_all(&mut follower_ctx).await
}
)
.unwrap();
}
}

View File

@@ -1,10 +1,10 @@
//! PRF modes.
//! Modes of operation.
/// Modes for the PRF.
#[derive(Debug, Clone, Copy)]
/// Modes for the TLS 1.2 PRF and the TLS 1.3 key schedule.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Mode {
/// Computes some hashes locally.
Reduced,
/// Computes the whole PRF in MPC.
/// Computes the whole function in MPC.
Normal,
}

View File

@@ -3,15 +3,15 @@ use std::error::Error;
use mpz_hash::sha256::Sha256Error;
/// A PRF error.
/// An error type used by the functionalities of this crate.
#[derive(Debug, thiserror::Error)]
pub struct PrfError {
pub struct FError {
kind: ErrorKind,
#[source]
source: Option<Box<dyn Error + Send + Sync>>,
}
impl PrfError {
impl FError {
pub(crate) fn new<E>(kind: ErrorKind, source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync>>,
@@ -34,7 +34,7 @@ impl PrfError {
}
}
impl From<Sha256Error> for PrfError {
impl From<Sha256Error> for FError {
fn from(value: Sha256Error) -> Self {
Self::new(ErrorKind::Hash, value)
}
@@ -47,7 +47,7 @@ pub(crate) enum ErrorKind {
Hash,
}
impl fmt::Display for PrfError {
impl fmt::Display for FError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.kind {
ErrorKind::Vm => write!(f, "vm error")?,

View File

@@ -2,7 +2,7 @@
//!
//! HMAC-SHA256 is defined as
//!
//! HMAC(m) = H((key' xor opad) || H((key' xor ipad) || m))
//! HMAC(key, m) = H((key' xor opad) || H((key' xor ipad) || m))
//!
//! * H - SHA256 hash function
//! * key' - key padded with zero bytes to 64 bytes (we do not support longer
@@ -11,169 +11,307 @@
//! * ipad - 64 bytes of 0x36
//! * m - message
//!
//! This implementation computes HMAC-SHA256 using intermediate results
//! `outer_partial` and `inner_local`. Then HMAC(m) = H(outer_partial ||
//! inner_local)
//! We describe HMAC in terms of the SHA-256 compression function
//! C(IV, m), where `IV` is the hash state, `m` is the input block,
//! and the output is the updated state.
//!
//! * `outer_partial` - key' xor opad
//! * `inner_local` - H((key' xor ipad) || m)
//! HMAC(m) = C( C(IV, key' xor opad), C( C(IV, key' xor ipad), m) )
//!
//! Throughout this crate we use the following terminology for
//! intermediate states:
//!
//! * `outer_partial` — C(IV, key' ⊕ opad)
//! * `inner_partial` — C(IV, key' ⊕ ipad)
//! * `inner_local` — C(inner_partial, m)
//!
//! The final value is then computed as:
//!
//! HMAC(m) = C(outer_partial, inner_local)
use std::sync::Arc;
use crate::{
hmac::{normal::HmacNormal, reduced::HmacReduced},
sha256, state_to_bytes, Mode,
};
use mpz_circuits::circuits::xor;
use mpz_hash::sha256::Sha256;
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array,
Array, MemoryExt, Vector, ViewExt,
},
Vm,
Call, CallableExt, Vm,
};
use crate::PrfError;
use crate::FError;
pub(crate) mod clear;
pub(crate) mod normal;
pub(crate) mod reduced;
/// Inner padding of HMAC.
pub(crate) const IPAD: [u8; 64] = [0x36; 64];
/// Outer padding of HMAC.
pub(crate) const OPAD: [u8; 64] = [0x5c; 64];
/// Initial IV of SHA256.
pub(crate) const SHA256_IV: [u32; 8] = [
0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
];
/// Computes HMAC-SHA256
/// Functionality for HMAC computation with a private key and a public message.
#[derive(Debug)]
#[allow(dead_code)]
pub(crate) enum Hmac {
    /// Reduced mode: some hashes are computed locally on cleartext values.
    Reduced(reduced::HmacReduced),
    /// Normal mode: the whole computation runs in MPC.
    Normal(normal::HmacNormal),
}
impl Hmac {
/// Allocates a new HMAC with the given `key`.
pub(crate) fn alloc(
vm: &mut dyn Vm<Binary>,
key: Vector<U8>,
mode: Mode,
) -> Result<Self, FError> {
match mode {
Mode::Reduced => Ok(Hmac::Reduced(HmacReduced::alloc(vm, key)?)),
Mode::Normal => Ok(Hmac::Normal(HmacNormal::alloc(vm, key)?)),
}
}
/// Whether this functionality needs to be flushed.
#[allow(dead_code)]
pub(crate) fn wants_flush(&self) -> bool {
match self {
Hmac::Reduced(hmac) => hmac.wants_flush(),
Hmac::Normal(hmac) => hmac.wants_flush(),
}
}
/// Flushes the functionality.
#[allow(dead_code)]
pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), FError> {
match self {
Hmac::Reduced(hmac) => hmac.flush(vm),
Hmac::Normal(hmac) => hmac.flush(),
}
}
/// Returns HMAC output.
#[allow(dead_code)]
pub(crate) fn output(&self) -> Result<Array<U8, 32>, FError> {
match self {
Hmac::Reduced(hmac) => Ok(hmac.output()),
Hmac::Normal(hmac) => hmac.output(),
}
}
/// Creates a new allocated instance of HMAC from another instance.
pub(crate) fn from_other(vm: &mut dyn Vm<Binary>, other: &Self) -> Result<Self, FError> {
match other {
Hmac::Reduced(hmac) => Ok(Hmac::Reduced(HmacReduced::from_other(vm, hmac)?)),
Hmac::Normal(hmac) => Ok(Hmac::Normal(HmacNormal::from_other(hmac)?)),
}
}
}
/// Computes HMAC-SHA256.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `outer_partial` - (key' xor opad)
/// * `inner_local` - H((key' xor ipad) || m)
/// * `outer_partial` - the outer partial hash state, C(IV, key' xor opad).
/// * `inner_local` - the inner local hash, C(inner_partial, m).
pub(crate) fn hmac_sha256(
vm: &mut dyn Vm<Binary>,
mut outer_partial: Sha256,
inner_local: Array<U8, 32>,
) -> Result<Array<U8, 32>, PrfError> {
) -> Result<Array<U8, 32>, FError> {
outer_partial.update(&inner_local.into());
outer_partial.compress(vm)?;
outer_partial.finalize(vm).map_err(PrfError::from)
outer_partial.finalize(vm).map_err(FError::from)
}
/// Depending on the provided `mask` computes and returns outer_partial or
/// inner_partial for HMAC-SHA256.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `key` - Key to pad and xor.
/// * `mask`- Mask used for padding.
fn compute_partial(
    vm: &mut dyn Vm<Binary>,
    key: Vector<U8>,
    mask: [u8; 64],
) -> Result<Sha256, FError> {
    let xor_circ = Arc::new(xor(8 * 64));

    // Zero bytes padding the key up to the 64-byte block size; public.
    let pad_len = 64 - key.len();
    let padding_ref: Vector<U8> = vm.alloc_vec(pad_len).map_err(FError::vm)?;
    vm.mark_public(padding_ref).map_err(FError::vm)?;
    vm.assign(padding_ref, vec![0_u8; pad_len]).map_err(FError::vm)?;
    vm.commit(padding_ref).map_err(FError::vm)?;

    // The ipad/opad mask is public as well.
    let mask_ref: Array<U8, 64> = vm.alloc().map_err(FError::vm)?;
    vm.mark_public(mask_ref).map_err(FError::vm)?;
    vm.assign(mask_ref, mask).map_err(FError::vm)?;
    vm.commit(mask_ref).map_err(FError::vm)?;

    // key_padded = (key || padding) xor mask.
    let call = Call::builder(xor_circ)
        .arg(key)
        .arg(padding_ref)
        .arg(mask_ref)
        .build()
        .map_err(FError::vm)?;
    let key_padded: Vector<U8> = vm.call(call).map_err(FError::vm)?;

    // Compress the masked key block into a fresh SHA-256 state.
    let mut hasher = Sha256::new_with_init(vm)?;
    hasher.update(&key_padded);
    hasher.compress(vm)?;
    Ok(hasher)
}
/// Computes inner_local on cleartext and assigns it to the given VM
/// reference.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `inner_local` - VM reference to assign to.
/// * `inner_partial` - inner_partial.
/// * `msg` - Message to be compressed.
pub(crate) fn assign_inner_local(
    vm: &mut dyn Vm<Binary>,
    inner_local: Array<U8, 32>,
    inner_partial: [u32; 8],
    msg: &[u8],
) -> Result<(), FError> {
    let value = state_to_bytes(sha256(inner_partial, 64, msg));
    vm.assign(inner_local, value).map_err(FError::vm)?;
    vm.commit(inner_local).map_err(FError::vm)?;
    Ok(())
}
#[cfg(test)]
mod tests {
use crate::{
hmac::hmac_sha256,
sha256, state_to_bytes,
test_utils::{compute_inner_local, compute_outer_partial},
};
use super::*;
use crate::test_utils::mock_vm;
use hmac::{Hmac as HmacReference, Mac};
use mpz_common::context::test_st_context;
use mpz_hash::sha256::Sha256;
use mpz_ideal_vm::IdealVm;
use mpz_vm_core::{
memory::{
binary::{U32, U8},
Array, MemoryExt, ViewExt,
},
memory::{MemoryExt, ViewExt},
Execute,
};
use rand::{rngs::StdRng, Rng, SeedableRng};
use rstest::*;
use sha2::Sha256;
#[test]
fn test_hmac_reference() {
let (inputs, references) = test_fixtures();
for (input, &reference) in inputs.iter().zip(references.iter()) {
let outer_partial = compute_outer_partial(input.0.clone());
let inner_local = compute_inner_local(input.0.clone(), &input.1);
let hmac = sha256(outer_partial, 64, &state_to_bytes(inner_local));
assert_eq!(state_to_bytes(hmac), reference);
}
}
#[rstest]
#[case::normal(Mode::Normal)]
#[case::reduced(Mode::Reduced)]
#[tokio::test]
async fn test_hmac_circuit() {
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let mut leader = IdealVm::new();
let mut follower = IdealVm::new();
async fn test_hmac(#[case] mode: Mode) {
let mut rng = StdRng::from_seed([2; 32]);
let (inputs, references) = test_fixtures();
for (input, &reference) in inputs.iter().zip(references.iter()) {
let outer_partial = compute_outer_partial(input.0.clone());
let inner_local = compute_inner_local(input.0.clone(), &input.1);
for _ in 0..10 {
let key: [u8; 32] = rng.random();
let msg: [u8; 32] = rng.random();
let outer_partial_leader: Array<U32, 8> = leader.alloc().unwrap();
leader.mark_public(outer_partial_leader).unwrap();
leader.assign(outer_partial_leader, outer_partial).unwrap();
leader.commit(outer_partial_leader).unwrap();
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut leader, mut follower) = mock_vm();
let inner_local_leader: Array<U8, 32> = leader.alloc().unwrap();
leader.mark_public(inner_local_leader).unwrap();
leader
.assign(inner_local_leader, state_to_bytes(inner_local))
.unwrap();
leader.commit(inner_local_leader).unwrap();
let vm = &mut leader;
let key_ref = vm.alloc_vec(32).unwrap();
vm.mark_public(key_ref).unwrap();
vm.assign(key_ref, key.to_vec()).unwrap();
vm.commit(key_ref).unwrap();
let mut hmac_leader = Hmac::alloc(vm, key_ref, mode).unwrap();
let hmac_leader = hmac_sha256(
&mut leader,
Sha256::new_from_state(outer_partial_leader, 1),
inner_local_leader,
)
.unwrap();
let hmac_leader = leader.decode(hmac_leader).unwrap();
if mode == Mode::Reduced {
if let Hmac::Reduced(ref mut hmac) = hmac_leader {
hmac.set_msg(&msg).unwrap();
};
} else if let Hmac::Normal(ref mut hmac) = hmac_leader {
let msg_ref = vm.alloc_vec(msg.len()).unwrap();
vm.mark_public(msg_ref).unwrap();
vm.assign(msg_ref, msg.to_vec()).unwrap();
vm.commit(msg_ref).unwrap();
hmac.set_msg(vm, &[msg_ref]).unwrap();
}
let leader_out = hmac_leader.output().unwrap();
let mut leader_out = vm.decode(leader_out).unwrap();
let outer_partial_follower: Array<U32, 8> = follower.alloc().unwrap();
follower.mark_public(outer_partial_follower).unwrap();
follower
.assign(outer_partial_follower, outer_partial)
.unwrap();
follower.commit(outer_partial_follower).unwrap();
let vm = &mut follower;
let key_ref = vm.alloc_vec(32).unwrap();
vm.mark_public(key_ref).unwrap();
vm.assign(key_ref, key.to_vec()).unwrap();
vm.commit(key_ref).unwrap();
let mut hmac_follower = Hmac::alloc(vm, key_ref, mode).unwrap();
let inner_local_follower: Array<U8, 32> = follower.alloc().unwrap();
follower.mark_public(inner_local_follower).unwrap();
follower
.assign(inner_local_follower, state_to_bytes(inner_local))
.unwrap();
follower.commit(inner_local_follower).unwrap();
if mode == Mode::Reduced {
if let Hmac::Reduced(ref mut hmac) = hmac_follower {
hmac.set_msg(&msg).unwrap();
};
} else if let Hmac::Normal(ref mut hmac) = hmac_follower {
let msg_ref = vm.alloc_vec(msg.len()).unwrap();
vm.mark_public(msg_ref).unwrap();
vm.assign(msg_ref, msg.to_vec()).unwrap();
vm.commit(msg_ref).unwrap();
hmac.set_msg(vm, &[msg_ref]).unwrap();
}
let follower_out = hmac_follower.output().unwrap();
let mut follower_out = vm.decode(follower_out).unwrap();
let hmac_follower = hmac_sha256(
&mut follower,
Sha256::new_from_state(outer_partial_follower, 1),
inner_local_follower,
)
.unwrap();
let hmac_follower = follower.decode(hmac_follower).unwrap();
let (hmac_leader, hmac_follower) = tokio::try_join!(
tokio::try_join!(
async {
assert!(hmac_leader.wants_flush());
hmac_leader.flush(&mut leader).unwrap();
leader.execute_all(&mut ctx_a).await.unwrap();
hmac_leader.await
// In reduced mode two flushes are required.
if mode == Mode::Reduced {
assert!(hmac_leader.wants_flush());
hmac_leader.flush(&mut leader).unwrap();
leader.execute_all(&mut ctx_a).await.unwrap();
}
assert!(!hmac_leader.wants_flush());
Ok::<(), Box<dyn std::error::Error>>(())
},
async {
assert!(hmac_follower.wants_flush());
hmac_follower.flush(&mut follower).unwrap();
follower.execute_all(&mut ctx_b).await.unwrap();
hmac_follower.await
// In reduced mode two flushes are required.
if mode == Mode::Reduced {
assert!(hmac_follower.wants_flush());
hmac_follower.flush(&mut follower).unwrap();
follower.execute_all(&mut ctx_b).await.unwrap();
}
assert!(!hmac_follower.wants_flush());
Ok::<(), Box<dyn std::error::Error>>(())
}
)
.unwrap();
assert_eq!(hmac_leader, hmac_follower);
assert_eq!(hmac_leader, reference);
let leader_out = leader_out.try_recv().unwrap().unwrap();
let follower_out = follower_out.try_recv().unwrap().unwrap();
let mut hmac_ref = HmacReference::<Sha256>::new_from_slice(&key).unwrap();
hmac_ref.update(&msg);
assert_eq!(leader_out, follower_out);
assert_eq!(leader_out, *hmac_ref.finalize().into_bytes());
}
}
#[allow(clippy::type_complexity)]
fn test_fixtures() -> (Vec<(Vec<u8>, Vec<u8>)>, Vec<[u8; 32]>) {
let test_vectors: Vec<(Vec<u8>, Vec<u8>)> = vec![
(
hex::decode("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b").unwrap(),
hex::decode("4869205468657265").unwrap(),
),
(
hex::decode("4a656665").unwrap(),
hex::decode("7768617420646f2079612077616e7420666f72206e6f7468696e673f").unwrap(),
),
];
let expected: Vec<[u8; 32]> = vec![
hex::decode("b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7")
.unwrap()
.try_into()
.unwrap(),
hex::decode("5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843")
.unwrap()
.try_into()
.unwrap(),
];
(test_vectors, expected)
}
}

View File

@@ -0,0 +1,72 @@
//! Computation of HMAC-SHA256 on cleartext values.
use crate::{
compress_256,
hmac::{IPAD, OPAD, SHA256_IV},
sha256, state_to_bytes,
};
/// Depending on the provided `mask` computes and returns outer_partial or
/// inner_partial for HMAC-SHA256.
fn compute_partial(key: &[u8], mask: &[u8; 64]) -> [u32; 8] {
    assert!(key.len() <= 64);

    // Zero-pad the key to the 64-byte block size, then xor with the mask.
    let mut block = [0_u8; 64];
    block[..key.len()].copy_from_slice(key);
    for (byte, m) in block.iter_mut().zip(mask.iter()) {
        *byte ^= m;
    }

    compress_256(SHA256_IV, &block)
}
/// Computes and returns inner_partial for HMAC-SHA256,
/// i.e. C(IV, key' xor ipad).
pub(crate) fn compute_inner_partial(key: &[u8]) -> [u32; 8] {
    compute_partial(key, &IPAD)
}

/// Computes and returns outer_partial for HMAC-SHA256,
/// i.e. C(IV, key' xor opad).
pub(crate) fn compute_outer_partial(key: &[u8]) -> [u32; 8] {
    compute_partial(key, &OPAD)
}

/// Computes and returns inner_local for HMAC-SHA256; the `64` accounts for
/// the one key block already compressed into inner_partial.
fn compute_inner_local(key: &[u8], msg: &[u8]) -> [u32; 8] {
    sha256(compute_inner_partial(key), 64, msg)
}
/// Computes and returns the HMAC-SHA256 output on cleartext values.
pub(crate) fn hmac_sha256(key: &[u8], msg: &[u8]) -> [u8; 32] {
    // HMAC(m) = C(outer_partial, inner_local).
    let inner_local = compute_inner_local(key, msg);
    let digest = sha256(compute_outer_partial(key), 64, &state_to_bytes(inner_local));
    state_to_bytes(digest)
}
#[cfg(test)]
mod tests {
    use super::*;
    use hmac::{Hmac, Mac};
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use sha2::Sha256;

    /// Checks the cleartext HMAC-SHA256 against the `hmac` crate reference
    /// on random keys and messages.
    #[test]
    fn test_hmac_sha256() {
        let mut rng = StdRng::from_seed([1; 32]);

        for _ in 0..10 {
            let (key, msg): ([u8; 32], [u8; 32]) = (rng.random(), rng.random());

            let mut reference =
                Hmac::<Sha256>::new_from_slice(&key).expect("HMAC can take key of any size");
            reference.update(&msg);
            let expected = reference.finalize().into_bytes();

            assert_eq!(hmac_sha256(&key, &msg), *expected);
        }
    }
}

View File

@@ -0,0 +1,141 @@
use mpz_hash::sha256::Sha256;
use mpz_vm_core::{
memory::{
binary::{Binary, U32, U8},
Array, MemoryExt, Vector, ViewExt,
},
Vm,
};
use crate::{
hmac::{compute_partial, hmac_sha256, IPAD, OPAD},
FError,
};
/// Functionality for HMAC computation with a private key and a public message.
///
/// Used in conjunction with [crate::Mode::Normal].
#[derive(Debug, Clone)]
pub(crate) struct HmacNormal {
    /// Hash state after compressing key' xor ipad.
    inner_partial: Sha256,
    /// Hash state after compressing key' xor opad.
    outer_partial: Sha256,
    /// VM reference to the HMAC output; set once the message is known.
    output: Option<Array<U8, 32>>,
    /// Current step of the functionality's state machine.
    state: State,
}
impl HmacNormal {
/// Allocates a new HMAC with the given `key`.
pub(crate) fn alloc(vm: &mut dyn Vm<Binary>, key: Vector<U8>) -> Result<Self, FError> {
Ok(Self {
inner_partial: compute_partial(vm, key, IPAD)?,
outer_partial: compute_partial(vm, key, OPAD)?,
output: None,
state: State::WantsMsg,
})
}
/// Allocates a new HMAC with the given `inner_partial` and
/// `outer_partial`.
pub(crate) fn alloc_with_state(
vm: &mut dyn Vm<Binary>,
inner_partial: [u32; 8],
outer_partial: [u32; 8],
) -> Result<Self, FError> {
let inner_p: Array<U32, 8> = vm.alloc().map_err(FError::vm)?;
vm.mark_public(inner_p).map_err(FError::vm)?;
vm.assign(inner_p, inner_partial).map_err(FError::vm)?;
vm.commit(inner_p).map_err(FError::vm)?;
let inner = Sha256::new_from_state(inner_p, 1);
let outer_p: Array<U32, 8> = vm.alloc().map_err(FError::vm)?;
vm.mark_public(outer_p).map_err(FError::vm)?;
vm.assign(outer_p, outer_partial).map_err(FError::vm)?;
vm.commit(outer_p).map_err(FError::vm)?;
let outer = Sha256::new_from_state(outer_p, 1);
Ok(Self {
inner_partial: inner,
outer_partial: outer,
output: None,
state: State::WantsMsg,
})
}
/// Whether this functionality needs to be flushed.
pub(crate) fn wants_flush(&self) -> bool {
matches!(self.state, State::MsgSet)
}
/// Flushes the functionality.
pub(crate) fn flush(&mut self) -> Result<(), FError> {
if let State::MsgSet = self.state {
self.state = State::Complete;
}
Ok(())
}
/// Sets an HMAC message `msg`.
///
/// The message is a slice of vectors which will be concatenated.
pub(crate) fn set_msg(
&mut self,
vm: &mut dyn Vm<Binary>,
msg: &[Vector<U8>],
) -> Result<(), FError> {
match self.state {
State::WantsMsg => (),
_ => return Err(FError::state("must be in WantsMsg state to set message")),
}
msg.iter().for_each(|m| self.inner_partial.update(m));
self.inner_partial.compress(vm).map_err(FError::vm)?;
let inner_local = self.inner_partial.finalize(vm).map_err(FError::vm)?;
let out = hmac_sha256(vm, self.outer_partial.clone(), inner_local)?;
self.output = Some(out);
self.state = State::MsgSet;
Ok(())
}
/// Returns HMAC output.
pub(crate) fn output(&self) -> Result<Array<U8, 32>, FError> {
match self.state {
State::MsgSet | State::Complete => Ok(self
.output
.expect("output is available when message is set")),
_ => Err(FError::state(
"must be in MsgSet or Complete state to return output",
)),
}
}
/// Creates a new allocated instance of HMAC from another instance.
pub(crate) fn from_other(other: &Self) -> Result<Self, FError> {
match other.state {
State::WantsMsg => Ok(Self {
inner_partial: other.inner_partial.clone(),
outer_partial: other.outer_partial.clone(),
output: None,
state: State::WantsMsg,
}),
_ => Err(FError::state("other must be in WantsMsg state")),
}
}
/// Whether this functionality is complete.
pub(crate) fn is_complete(&self) -> bool {
matches!(self.state, State::Complete)
}
}
/// State of [HmacNormal].
#[derive(Debug, Clone)]
pub(crate) enum State {
    /// Waiting for the message to be set.
    WantsMsg,
    /// The state after the message has been set.
    MsgSet,
    /// The functionality is complete.
    Complete,
}

View File

@@ -0,0 +1,163 @@
use crate::hmac::{assign_inner_local, compute_partial, hmac_sha256, IPAD, OPAD};
use mpz_hash::sha256::Sha256;
use mpz_vm_core::{
memory::{
binary::{Binary, U32, U8},
Array, MemoryExt, Vector, ViewExt,
},
Vm,
};
use crate::FError;
/// Functionality for HMAC computation with a private key and a public message.
///
/// Used in conjunction with [crate::Mode::Reduced].
#[derive(Debug)]
pub(crate) struct HmacReduced {
    // Partial SHA-256 state computed from the key and `OPAD`.
    outer_partial: Sha256,
    // Public VM reference for the inner hash; assigned during flushing once
    // the message is known.
    inner_local: Array<U8, 32>,
    // VM reference to the partial state computed from the key and `IPAD`;
    // decoded to plaintext during flushing.
    inner_partial: Array<U32, 8>,
    // Cleartext HMAC message; set via `set_msg`.
    msg: Option<Vec<u8>>,
    // HMAC output.
    output: Array<U8, 32>,
    // Current state of the functionality.
    state: State,
}
impl HmacReduced {
    /// Allocates a new HMAC with the given `key`.
    pub(crate) fn alloc(vm: &mut dyn Vm<Binary>, key: Vector<U8>) -> Result<Self, FError> {
        let inner_partial = compute_partial(vm, key, IPAD)?;
        let outer_partial = compute_partial(vm, key, OPAD)?;

        let (inner_partial, _) = inner_partial
            .state()
            .expect("state should be set for inner_partial");

        // Decode as soon as the value is computed.
        std::mem::drop(vm.decode(inner_partial).map_err(FError::vm)?);

        let inner_local: Array<U8, 32> = vm.alloc().map_err(FError::vm)?;
        vm.mark_public(inner_local).map_err(FError::vm)?;

        let out = hmac_sha256(vm, outer_partial.clone(), inner_local)?;

        Ok(Self {
            outer_partial,
            inner_local,
            inner_partial,
            msg: None,
            output: out,
            state: State::WantsInnerPartial,
        })
    }

    /// Whether this functionality needs to be flushed.
    pub(crate) fn wants_flush(&self) -> bool {
        match self.state {
            State::WantsInnerPartial => true,
            State::WantsMsg { .. } => self.msg.is_some(),
            _ => false,
        }
    }

    /// Flushes the functionality.
    pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), FError> {
        let state = self.state.take();
        match state {
            State::WantsInnerPartial => {
                let mut inner_partial = vm.decode(self.inner_partial).map_err(FError::vm)?;
                let Some(inner_partial) = inner_partial.try_recv().map_err(FError::vm)? else {
                    // Not decoded yet; keep waiting.
                    self.state = State::WantsInnerPartial;
                    return Ok(());
                };

                self.state = State::WantsMsg { inner_partial };
                // Recurse: the message may already be available.
                self.flush(vm)?;
            }
            State::WantsMsg { inner_partial } => {
                // The output can only be computed after the message was set.
                if let Some(msg) = self.msg.as_deref() {
                    assign_inner_local(vm, self.inner_local, inner_partial, msg)?;
                    self.state = State::Complete;
                } else {
                    self.state = State::WantsMsg { inner_partial };
                }
            }
            _ => self.state = state,
        }

        Ok(())
    }

    /// Sets the HMAC message.
    ///
    /// # Errors
    ///
    /// Errors if the message has already been set.
    pub(crate) fn set_msg(&mut self, msg: &[u8]) -> Result<(), FError> {
        match self.msg {
            None => self.msg = Some(msg.to_vec()),
            Some(_) => return Err(FError::state("message has already been set")),
        }

        Ok(())
    }

    /// Whether the HMAC message has been set.
    pub(crate) fn is_msg_set(&self) -> bool {
        self.msg.is_some()
    }

    /// Returns the HMAC output.
    pub(crate) fn output(&self) -> Array<U8, 32> {
        self.output
    }

    /// Creates a new allocated instance of HMAC from another instance.
    ///
    /// # Errors
    ///
    /// Errors unless `other` is still waiting for the decoded inner partial
    /// state.
    pub(crate) fn from_other(vm: &mut dyn Vm<Binary>, other: &Self) -> Result<Self, FError> {
        match other.state {
            State::WantsInnerPartial => {
                let inner_local: Array<U8, 32> = vm.alloc().map_err(FError::vm)?;
                vm.mark_public(inner_local).map_err(FError::vm)?;

                let out = hmac_sha256(vm, other.outer_partial.clone(), inner_local)?;

                Ok(Self {
                    outer_partial: other.outer_partial.clone(),
                    inner_local,
                    inner_partial: other.inner_partial,
                    msg: None,
                    output: out,
                    state: State::WantsInnerPartial,
                })
            }
            _ => Err(FError::state("other must be in WantsInnerPartial state")),
        }
    }

    /// Whether this functionality is complete.
    pub(crate) fn is_complete(&self) -> bool {
        matches!(self.state, State::Complete)
    }
}
/// State of [HmacReduced].
#[derive(Debug, Clone)]
pub(crate) enum State {
    /// Wants the decoded inner_partial plaintext.
    WantsInnerPartial,
    /// Wants the message to be set.
    WantsMsg {
        /// Decoded plaintext of the inner partial hash state.
        inner_partial: [u32; 8],
    },
    /// The functionality is complete.
    Complete,
    /// Placeholder left behind while the state is temporarily moved out in
    /// [State::take].
    Error,
}
impl State {
    /// Moves the state out, leaving [State::Error] in its place.
    pub(crate) fn take(&mut self) -> State {
        std::mem::replace(self, State::Error)
    }
}

View File

@@ -0,0 +1,292 @@
//! `HKDF-Expand-Label` function as defined in TLS 1.3.
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Vector,
},
Vm,
};
use crate::{
hmac::{clear, Hmac},
kdf::expand::label::make_hkdf_label,
FError, Mode,
};
pub(crate) mod label;
pub(crate) mod normal;
pub(crate) mod reduced;
/// A zero-length HKDF-Expand-Label context.
pub(crate) const EMPTY_CTX: [u8; 0] = [];
/// Functionality for computing `HKDF-Expand-Label` with a private secret
/// and public label and context.
#[derive(Debug)]
pub(crate) enum HkdfExpand {
    /// Variant used with [crate::Mode::Reduced].
    Reduced(reduced::HkdfExpand),
    /// Variant used with [crate::Mode::Normal].
    Normal(normal::HkdfExpand),
}
impl HkdfExpand {
    /// Allocates a new HKDF-Expand-Label with the `hmac`
    /// instantiated with the secret.
    pub(crate) fn alloc(
        mode: Mode,
        vm: &mut dyn Vm<Binary>,
        // Partial hash states of the secret.
        hmac: Hmac,
        // Human-readable label.
        label: &'static [u8],
        // Context.
        ctx: Option<&[u8]>,
        // Context length.
        ctx_len: usize,
        // Output length.
        out_len: usize,
    ) -> Result<Self, FError> {
        // The HMAC variant is expected to agree with the requested mode.
        match (mode, hmac) {
            (Mode::Reduced, Hmac::Reduced(hmac)) => {
                let mut hkdf = reduced::HkdfExpand::alloc(hmac, label, out_len)?;
                if let Some(ctx) = ctx {
                    hkdf.set_ctx(ctx)?;
                }
                Ok(Self::Reduced(hkdf))
            }
            (Mode::Normal, Hmac::Normal(hmac)) => {
                let mut hkdf = normal::HkdfExpand::alloc(vm, hmac, label, ctx_len, out_len)?;
                if let Some(ctx) = ctx {
                    hkdf.set_ctx(ctx)?;
                }
                Ok(Self::Normal(hkdf))
            }
            _ => unreachable!("modes always match"),
        }
    }

    /// Whether this functionality needs to be flushed.
    pub(crate) fn wants_flush(&self) -> bool {
        match self {
            Self::Reduced(inner) => inner.wants_flush(),
            Self::Normal(inner) => inner.wants_flush(),
        }
    }

    /// Flushes the functionality.
    pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), FError> {
        match self {
            Self::Reduced(inner) => inner.flush(vm),
            Self::Normal(inner) => inner.flush(vm),
        }
    }

    /// Sets the HKDF-Expand-Label context.
    pub(crate) fn set_ctx(&mut self, ctx: &[u8]) -> Result<(), FError> {
        match self {
            Self::Reduced(inner) => inner.set_ctx(ctx),
            Self::Normal(inner) => inner.set_ctx(ctx),
        }
    }

    /// Whether the context has been set.
    pub(crate) fn is_ctx_set(&self) -> bool {
        match self {
            Self::Reduced(inner) => inner.is_ctx_set(),
            Self::Normal(inner) => inner.is_ctx_set(),
        }
    }

    /// Returns the HKDF-Expand-Label output.
    pub(crate) fn output(&self) -> Vector<U8> {
        match self {
            Self::Reduced(inner) => inner.output(),
            Self::Normal(inner) => inner.output(),
        }
    }

    /// Whether this functionality is complete.
    pub(crate) fn is_complete(&self) -> bool {
        match self {
            Self::Reduced(inner) => inner.is_complete(),
            Self::Normal(inner) => inner.is_complete(),
        }
    }
}
/// Computes `HKDF-Expand-Label` as defined in TLS 1.3 on cleartext values.
///
/// `key` is the PRK, `label` the human-readable label (without the "tls13 "
/// prefix), `ctx` the context, and `len` the output length in bytes
/// (at most 32, see `hkdf_expand`).
pub(crate) fn hkdf_expand_label(key: &[u8], label: &[u8], ctx: &[u8], len: usize) -> Vec<u8> {
    hkdf_expand(key, &make_hkdf_label(label, ctx, len), len)
}
/// Computes `HKDF-Expand` as defined in https://datatracker.ietf.org/doc/html/rfc5869
///
/// Only a single output block is supported, hence `len` must be at most 32.
fn hkdf_expand(prk: &[u8], info: &[u8], len: usize) -> Vec<u8> {
    assert!(len <= 32, "output length larger than 32 is not supported");

    // T(1) = HMAC-Hash(PRK, T(0) | info | 0x01), with T(0) empty.
    let mut t1_input = Vec::with_capacity(info.len() + 1);
    t1_input.extend_from_slice(info);
    t1_input.push(0x01);

    clear::hmac_sha256(prk, &t1_input)[..len].to_vec()
}
#[cfg(test)]
mod tests {
    use crate::{
        hmac::{normal::HmacNormal, Hmac},
        kdf::expand::{hkdf_expand_label, HkdfExpand},
        test_utils::mock_vm,
        Mode,
    };
    use mpz_common::context::test_st_context;
    use mpz_vm_core::{
        memory::{binary::Binary, MemoryExt, ViewExt},
        Execute, Vm,
    };
    use rstest::*;

    /// Runs the two-party HKDF-Expand-Label in both modes and checks the
    /// jointly computed output against the reference vectors.
    #[rstest]
    #[case::normal(Mode::Normal)]
    #[case::reduced(Mode::Reduced)]
    #[tokio::test]
    async fn test_hkdf_expand(#[case] mode: Mode) {
        for fixture in test_fixtures() {
            let (label, prk, ctx, output) = fixture;
            let (mut ctx_a, mut ctx_b) = test_st_context(8);
            let (mut leader, mut follower) = mock_vm();

            // Commits the public secret and allocates an HKDF-Expand-Label
            // instance for one party.
            fn setup_hkdf(
                vm: &mut (dyn Vm<Binary> + Send),
                prk: [u8; 32],
                label: &'static [u8],
                ctx: Option<&[u8]>,
                ctx_len: usize,
                out_len: usize,
                mode: Mode,
            ) -> HkdfExpand {
                let secret = vm.alloc_vec(32).unwrap();
                vm.mark_public(secret).unwrap();
                vm.assign(secret, prk.to_vec()).unwrap();
                vm.commit(secret).unwrap();

                let hmac = if mode == Mode::Normal {
                    Hmac::Normal(HmacNormal::alloc(vm, secret).unwrap())
                } else {
                    use crate::hmac::reduced::HmacReduced;
                    Hmac::Reduced(HmacReduced::alloc(vm, secret).unwrap())
                };

                HkdfExpand::alloc(mode, vm, hmac, label, ctx, ctx_len, out_len).unwrap()
            }

            let mut hkdf_leader = setup_hkdf(
                &mut leader,
                prk.clone().try_into().unwrap(),
                label,
                Some(&ctx),
                ctx.len(),
                output.len(),
                mode,
            );
            let mut hkdf_follower = setup_hkdf(
                &mut follower,
                prk.clone().try_into().unwrap(),
                label,
                Some(&ctx),
                ctx.len(),
                output.len(),
                mode,
            );

            let out_leader = hkdf_leader.output();
            let mut leader_decode_fut = leader.decode(out_leader).unwrap();

            let out_follower = hkdf_follower.output();
            let mut follower_decode_fut = follower.decode(out_follower).unwrap();

            // Both parties execute, flush once, then execute again to finish
            // any work scheduled by the flush.
            tokio::try_join!(
                async {
                    leader.execute_all(&mut ctx_a).await.unwrap();
                    assert!(hkdf_leader.wants_flush());
                    hkdf_leader.flush(&mut leader).unwrap();
                    assert!(!hkdf_leader.wants_flush());
                    leader.execute_all(&mut ctx_a).await.unwrap();
                    Ok::<(), Box<dyn std::error::Error>>(())
                },
                async {
                    follower.execute_all(&mut ctx_b).await.unwrap();
                    assert!(hkdf_follower.wants_flush());
                    hkdf_follower.flush(&mut follower).unwrap();
                    assert!(!hkdf_follower.wants_flush());
                    follower.execute_all(&mut ctx_b).await.unwrap();
                    Ok::<(), Box<dyn std::error::Error>>(())
                }
            )
            .unwrap();

            let out_leader = leader_decode_fut.try_recv().unwrap().unwrap();
            let out_follower = follower_decode_fut.try_recv().unwrap().unwrap();

            assert_eq!(out_leader, out_follower);
            assert_eq!(out_leader, output);
        }
    }

    /// Checks the cleartext HKDF-Expand-Label against the reference vectors.
    #[test]
    fn test_hkdf_expand_label() {
        for fixture in test_fixtures() {
            let (label, prk, ctx, output) = fixture;

            let out = hkdf_expand_label(&prk, label, &ctx, output.len());
            assert_eq!(out, output);
        }
    }

    // Reference values from https://datatracker.ietf.org/doc/html/draft-ietf-tls-tls13-vectors-06
    #[allow(clippy::type_complexity)]
    fn test_fixtures() -> Vec<(&'static [u8], Vec<u8>, Vec<u8>, Vec<u8>)> {
        vec![(
            // LABEL
            b"c hs traffic",
            // PRK
            from_hex_str("5b 4f 96 5d f0 3c 68 2c 46 e6 ee 86 c3 11 63 66 15 a1 d2 bb b2 43 45 c2 52 05 95 3c 87 9e 8d 06").to_vec(),
            // CTX
            from_hex_str("c6 c9 18 ad 2f 41 99 d5 59 8e af 01 16 cb 7a 5c 2c 14 cb 54 78 12 18 88 8d b7 03 0d d5 0d 5e 6d").to_vec(),
            // OUTPUT
            from_hex_str("e2 e2 32 07 bd 93 fb 7f e4 fc 2e 29 7a fe ab 16 0e 52 2b 5a b7 5d 64 a8 6e 75 bc ac 3f 3e 51 03").to_vec(),
        ),
        (
            // LABEL
            b"s hs traffic",
            // PRK
            from_hex_str("5b 4f 96 5d f0 3c 68 2c 46 e6 ee 86 c3 11 63 66 15 a1 d2 bb b2 43 45 c2 52 05 95 3c 87 9e 8d 06").to_vec(),
            // CTX
            from_hex_str("c6 c9 18 ad 2f 41 99 d5 59 8e af 01 16 cb 7a 5c 2c 14 cb 54 78 12 18 88 8d b7 03 0d d5 0d 5e 6d").to_vec(),
            // OUTPUT
            from_hex_str("3b 7a 83 9c 23 9e f2 bf 0b 73 05 a0 e0 c4 e5 a8 c6 c6 93 30 a7 53 b3 08 f5 e3 a8 3a a2 ef 69 79").to_vec(),
        ),
        (
            // LABEL
            b"c ap traffic",
            // PRK
            from_hex_str("5c 79 d1 69 42 4e 26 2b 56 32 03 62 7b e4 eb 51 03 3f 58 8c 43 c9 ce 03 73 37 2d bc bc 01 85 a7").to_vec(),
            // CTX
            from_hex_str("f8 c1 9e 8c 77 c0 38 79 bb c8 eb 6d 56 e0 0d d5 d8 6e f5 59 27 ee fc 08 e1 b0 02 b6 ec e0 5d bf").to_vec(),
            // OUTPUT
            from_hex_str("e2 f0 db 6a 82 e8 82 80 fc 26 f7 3c 89 85 4e e8 61 5e 25 df 28 b2 20 79 62 fa 78 22 26 b2 36 26").to_vec(),
        )
        ]
    }

    /// Decodes a whitespace-separated hex string.
    fn from_hex_str(s: &str) -> Vec<u8> {
        hex::decode(s.split_whitespace().collect::<String>()).unwrap()
    }
}

View File

@@ -0,0 +1,267 @@
//! Computation of HkdfLabel as specified in TLS 1.3.
use crate::FError;
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
MemoryExt, Vector, ViewExt,
},
Vm,
};
/// Functionality for HkdfLabel computation.
#[derive(Debug)]
pub(crate) struct HkdfLabel {
    /// Cleartext label.
    label: HkdfLabelClear,
    // VM reference for the HKDF label; assigned once the context is known.
    output: Vector<U8>,
    // Label context.
    ctx: Option<Vec<u8>>,
    // Current state of the functionality.
    state: State,
}
impl HkdfLabel {
    /// Allocates a new HkdfLabel.
    pub(crate) fn alloc(
        vm: &mut dyn Vm<Binary>,
        label: &'static [u8],
        ctx_len: usize,
        out_len: usize,
    ) -> Result<Self, FError> {
        let len = hkdf_label_length(label.len(), ctx_len);
        let output = vm.alloc_vec::<U8>(len).map_err(FError::vm)?;
        vm.mark_public(output).map_err(FError::vm)?;

        Ok(Self {
            label: HkdfLabelClear::new(label, out_len),
            output,
            ctx: None,
            state: State::WantsContext,
        })
    }

    /// Whether this functionality needs to be flushed.
    pub(crate) fn wants_flush(&self) -> bool {
        matches!(self.state, State::WantsContext) && self.is_ctx_set()
    }

    /// Flushes the functionality, assigning the label bytes once the context
    /// is available.
    pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), FError> {
        if !matches!(self.state, State::WantsContext) {
            return Ok(());
        }
        let Some(ctx) = self.ctx.as_deref() else {
            return Ok(());
        };

        self.label.set_ctx(ctx)?;
        vm.assign(self.output, self.label.output()?)
            .map_err(FError::vm)?;
        vm.commit(self.output).map_err(FError::vm)?;

        self.state = State::Complete;
        Ok(())
    }

    /// Sets label context.
    pub(crate) fn set_ctx(&mut self, ctx: &[u8]) -> Result<(), FError> {
        if self.ctx.is_some() {
            return Err(FError::state("context has already been set"));
        }
        self.ctx = Some(ctx.to_vec());
        Ok(())
    }

    /// Returns the HkdfLabel output.
    pub(crate) fn output(&self) -> Vector<U8> {
        self.output
    }

    /// Whether this functionality is complete.
    pub(crate) fn is_complete(&self) -> bool {
        matches!(self.state, State::Complete)
    }

    /// Returns whether context has been set.
    fn is_ctx_set(&self) -> bool {
        self.ctx.is_some()
    }
}
/// State of [HkdfLabel].
#[derive(Debug)]
enum State {
    /// Wants the context to be set.
    WantsContext,
    /// The functionality is complete.
    Complete,
}
/// Functionality for HkdfLabel computation on cleartext values.
#[derive(Debug)]
pub(crate) struct HkdfLabelClear {
    /// Human-readable label, without the "tls13 " prefix.
    label: &'static [u8],
    /// Context.
    ctx: Option<Vec<u8>>,
    /// Output length.
    out_len: usize,
}
impl HkdfLabelClear {
    /// Creates a new label.
    pub(crate) fn new(label: &'static [u8], out_len: usize) -> Self {
        Self {
            label,
            ctx: None,
            out_len,
        }
    }

    /// Sets label context.
    ///
    /// # Errors
    ///
    /// Errors if the context has already been set.
    pub(crate) fn set_ctx(&mut self, ctx: &[u8]) -> Result<(), FError> {
        match self.ctx {
            Some(_) => Err(FError::state("context has already been set")),
            None => {
                self.ctx = Some(ctx.to_vec());
                Ok(())
            }
        }
    }

    /// Returns the byte representation of the label.
    ///
    /// # Errors
    ///
    /// Errors if the context was not set.
    pub(crate) fn output(&self) -> Result<Vec<u8>, FError> {
        self.ctx
            .as_deref()
            .map(|ctx| make_hkdf_label(self.label, ctx, self.out_len))
            .ok_or_else(|| FError::state("context was not set"))
    }
}
/// Returns the byte representation of an HKDF label.
///
/// Layout: output length (u16, BE) | label length (u8) | "tls13 " | label |
/// context length (u8) | context.
///
/// # Panics
///
/// Panics if `out_len` is larger than 256, or if the prefixed label or the
/// context do not fit into their single-byte length fields.
pub(crate) fn make_hkdf_label(label: &[u8], ctx: &[u8], out_len: usize) -> Vec<u8> {
    assert!(
        out_len <= 256,
        "output length larger than 256 not supported"
    );

    const LABEL_PREFIX: &[u8] = b"tls13 ";

    let full_label_len = LABEL_PREFIX.len() + label.len();
    // The length fields are single bytes; longer values would previously have
    // been silently truncated by the `as u8` casts.
    assert!(full_label_len <= u8::MAX as usize, "label too long");
    assert!(ctx.len() <= u8::MAX as usize, "context too long");

    let mut hkdf_label = Vec::with_capacity(2 + 1 + full_label_len + 1 + ctx.len());
    hkdf_label.extend_from_slice(&u16::to_be_bytes(out_len as u16));
    hkdf_label.push(full_label_len as u8);
    hkdf_label.extend_from_slice(LABEL_PREFIX);
    hkdf_label.extend_from_slice(label);
    hkdf_label.push(ctx.len() as u8);
    hkdf_label.extend_from_slice(ctx);

    hkdf_label
}
/// Returns the length of an HKDF label in bytes.
///
/// See `make_hkdf_label` for the layout.
fn hkdf_label_length(label_len: usize, ctx_len: usize) -> usize {
    // Fixed overhead: 2 (output length, u16) + 1 (label length, u8)
    // + 6 (length of the "tls13 " prefix) + 1 (context length, u8).
    const FIXED_OVERHEAD: usize = 2 + 1 + 6 + 1;
    FIXED_OVERHEAD + label_len + ctx_len
}
#[cfg(test)]
mod tests {
    use crate::kdf::expand::label::make_hkdf_label;

    /// Checks the HkdfLabel serialization against the reference vectors.
    #[test]
    fn test_make_hkdf_label() {
        for fixture in test_fixtures() {
            let (label, ctx, hkdf_label, out_len) = fixture;
            assert_eq!(make_hkdf_label(label, &ctx, out_len), hkdf_label);
        }
    }

    // Test vectors from https://datatracker.ietf.org/doc/html/draft-ietf-tls-tls13-vectors-06
    // (in that ref, `hash` is the context, `info` is the hkdf label).
    #[allow(clippy::type_complexity)]
    fn test_fixtures() -> Vec<(&'static [u8], Vec<u8>, Vec<u8>, usize)> {
        vec![
            (
                b"derived",
                from_hex_str("e3 b0 c4 42 98 fc 1c 14 9a fb f4 c8 99 6f b9 24 27 ae 41 e4 64 9b 93 4c a4 95 99 1b 78 52 b8 55"),
                from_hex_str("00 20 0d 74 6c 73 31 33 20 64 65 72 69 76 65 64 20 e3 b0 c4 42 98 fc 1c 14 9a fb f4 c8 99 6f b9 24 27 ae 41 e4 64 9b 93 4c a4 95 99 1b 78 52 b8 55"),
                32,
            ),
            (
                b"c hs traffic",
                from_hex_str("c6 c9 18 ad 2f 41 99 d5 59 8e af 01 16 cb 7a 5c 2c 14 cb 54 78 12 18 88 8d b7 03 0d d5 0d 5e 6d"),
                from_hex_str("00 20 12 74 6c 73 31 33 20 63 20 68 73 20 74 72 61 66 66 69 63 20 c6 c9 18 ad 2f 41 99 d5 59 8e af 01 16 cb 7a 5c 2c 14 cb 54 78 12 18 88 8d b7 03 0d d5 0d 5e 6d"),
                32,
            ),
            (
                b"s hs traffic",
                from_hex_str("c6 c9 18 ad 2f 41 99 d5 59 8e af 01 16 cb 7a 5c 2c 14 cb 54 78 12 18 88 8d b7 03 0d d5 0d 5e 6d"),
                from_hex_str("00 20 12 74 6c 73 31 33 20 73 20 68 73 20 74 72 61 66 66 69 63 20 c6 c9 18 ad 2f 41 99 d5 59 8e af 01 16 cb 7a 5c 2c 14 cb 54 78 12 18 88 8d b7 03 0d d5 0d 5e 6d"),
                32,
            ),
            (
                b"key",
                from_hex_str(""),
                from_hex_str("00 10 09 74 6c 73 31 33 20 6b 65 79 00"),
                16,
            ),
            (
                b"iv",
                from_hex_str(""),
                from_hex_str("00 0c 08 74 6c 73 31 33 20 69 76 00"),
                12,
            ),
            (
                b"finished",
                from_hex_str(""),
                from_hex_str("00 20 0e 74 6c 73 31 33 20 66 69 6e 69 73 68 65 64 00"),
                32,
            ),
            (
                b"c ap traffic",
                from_hex_str("f8 c1 9e 8c 77 c0 38 79 bb c8 eb 6d 56 e0 0d d5 d8 6e f5 59 27 ee fc 08 e1 b0 02 b6 ec e0 5d bf"),
                from_hex_str("00 20 12 74 6c 73 31 33 20 63 20 61 70 20 74 72 61 66 66 69 63 20 f8 c1 9e 8c 77 c0 38 79 bb c8 eb 6d 56 e0 0d d5 d8 6e f5 59 27 ee fc 08 e1 b0 02 b6 ec e0 5d bf"),
                32,
            ),
            (
                b"s ap traffic",
                from_hex_str("f8 c1 9e 8c 77 c0 38 79 bb c8 eb 6d 56 e0 0d d5 d8 6e f5 59 27 ee fc 08 e1 b0 02 b6 ec e0 5d bf"),
                from_hex_str("00 20 12 74 6c 73 31 33 20 73 20 61 70 20 74 72 61 66 66 69 63 20 f8 c1 9e 8c 77 c0 38 79 bb c8 eb 6d 56 e0 0d d5 d8 6e f5 59 27 ee fc 08 e1 b0 02 b6 ec e0 5d bf"),
                32,
            ),
            (
                b"exp master",
                from_hex_str("f8 c1 9e 8c 77 c0 38 79 bb c8 eb 6d 56 e0 0d d5 d8 6e f5 59 27 ee fc 08 e1 b0 02 b6 ec e0 5d bf"),
                from_hex_str("00 20 10 74 6c 73 31 33 20 65 78 70 20 6d 61 73 74 65 72 20 f8 c1 9e 8c 77 c0 38 79 bb c8 eb 6d 56 e0 0d d5 d8 6e f5 59 27 ee fc 08 e1 b0 02 b6 ec e0 5d bf"),
                32,
            ),
            (
                b"res master",
                from_hex_str("50 2f 86 b9 57 9e c0 53 d3 28 24 e2 78 0e f6 5c c4 37 a3 56 43 45 35 6b df 79 13 ec 3b 87 96 14"),
                from_hex_str("00 20 10 74 6c 73 31 33 20 72 65 73 20 6d 61 73 74 65 72 20 50 2f 86 b9 57 9e c0 53 d3 28 24 e2 78 0e f6 5c c4 37 a3 56 43 45 35 6b df 79 13 ec 3b 87 96 14"),
                32,
            ),
            (
                b"resumption",
                from_hex_str("00 00"),
                from_hex_str("00 20 10 74 6c 73 31 33 20 72 65 73 75 6d 70 74 69 6f 6e 02 00 00"),
                32,
            ),
        ]
    }

    /// Decodes a whitespace-separated hex string.
    fn from_hex_str(s: &str) -> Vec<u8> {
        hex::decode(s.split_whitespace().collect::<String>()).unwrap()
    }
}

View File

@@ -0,0 +1,134 @@
use crate::{
hmac::normal::HmacNormal, kdf::expand::label::HkdfLabel, tls12::merge_vectors, FError,
};
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
MemoryExt, Vector, ViewExt,
},
Vm,
};
/// State of [HkdfExpand].
#[derive(Debug)]
enum State {
    /// Wants the context to be set.
    WantsContext,
    /// Context has been set.
    ContextSet,
    /// The functionality is complete.
    Complete,
}
/// Functionality for computing `HKDF-Expand-Label` with a private secret
/// and public label and context.
#[derive(Debug)]
pub(crate) struct HkdfExpand {
    // The HkdfLabel ("info") allocated in the VM.
    label: HkdfLabel,
    // Current state of the functionality.
    state: State,
    // Label context.
    ctx: Option<Vec<u8>>,
    // HKDF-Expand-Label output, truncated to the requested length.
    output: Vector<U8>,
}
impl HkdfExpand {
    /// Allocates a new HKDF-Expand-Label with the `hmac`
    /// instantiated with the secret.
    pub(crate) fn alloc(
        vm: &mut dyn Vm<Binary>,
        mut hmac: HmacNormal,
        // Human-readable label.
        label: &'static [u8],
        // Context length.
        ctx_len: usize,
        // Output length.
        out_len: usize,
    ) -> Result<Self, FError> {
        assert!(
            out_len <= 32,
            "output length larger than 32 is not supported"
        );

        let hkdf_label = HkdfLabel::alloc(vm, label, ctx_len, out_len)?;
        let info = hkdf_label.output();

        // HKDF-Expand requires 0x01 to be concatenated.
        // see line: T(1) = HMAC-Hash(PRK, T(0) | info | 0x01) in
        // https://datatracker.ietf.org/doc/html/rfc5869
        let counter = vm.alloc_vec::<U8>(1).map_err(FError::vm)?;
        vm.mark_public(counter).map_err(FError::vm)?;
        vm.assign(counter, vec![0x01]).map_err(FError::vm)?;
        vm.commit(counter).map_err(FError::vm)?;

        let msg = merge_vectors(vm, vec![info, counter], info.len() + counter.len())?;
        hmac.set_msg(vm, &[msg])?;

        let mut output: Vector<U8> = hmac.output()?.into();
        output.truncate(out_len);

        Ok(Self {
            output,
            label: hkdf_label,
            ctx: None,
            state: State::WantsContext,
        })
    }

    /// Whether this functionality needs to be flushed.
    pub(crate) fn wants_flush(&self) -> bool {
        let pending_ctx = matches!(self.state, State::WantsContext) && self.is_ctx_set();
        pending_ctx || self.label.wants_flush()
    }

    /// Flushes the functionality.
    pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), FError> {
        self.label.flush(vm)?;

        match self.state {
            State::WantsContext => {
                let Some(ctx) = self.ctx.as_deref() else {
                    return Ok(());
                };
                self.label.set_ctx(ctx)?;
                self.label.flush(vm)?;
                self.state = State::ContextSet;
                // Recurse to pick up the ContextSet transition immediately.
                self.flush(vm)
            }
            State::ContextSet => {
                if self.label.is_complete() {
                    self.state = State::Complete;
                }
                Ok(())
            }
            State::Complete => Ok(()),
        }
    }

    /// Sets the HKDF-Expand-Label context.
    pub(crate) fn set_ctx(&mut self, ctx: &[u8]) -> Result<(), FError> {
        if self.is_ctx_set() {
            return Err(FError::state("context has already been set"));
        }
        self.ctx = Some(ctx.to_vec());
        Ok(())
    }

    /// Returns the HKDF-Expand-Label output.
    pub(crate) fn output(&self) -> Vector<U8> {
        self.output
    }

    /// Whether this functionality is complete.
    pub(crate) fn is_complete(&self) -> bool {
        matches!(self.state, State::Complete)
    }

    /// Whether the context has been set.
    pub(crate) fn is_ctx_set(&self) -> bool {
        self.ctx.is_some()
    }
}

View File

@@ -0,0 +1,125 @@
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Vector,
},
Vm,
};
use crate::{hmac::reduced::HmacReduced, kdf::expand::label::HkdfLabelClear, FError};
/// Functionality for computing `HKDF-Expand-Label` with a private secret
/// and public label and context.
#[derive(Debug)]
pub(crate) struct HkdfExpand {
    // Cleartext HkdfLabel ("info").
    label: HkdfLabelClear,
    // HMAC instantiated with the secret.
    hmac: HmacReduced,
    // Label context.
    ctx: Option<Vec<u8>>,
    // HKDF-Expand-Label output, truncated to the requested length.
    output: Vector<U8>,
    // Current state of the functionality.
    state: State,
}
impl HkdfExpand {
    /// Allocates a new HKDF-Expand-Label with the `hmac`
    /// instantiated with the secret.
    pub(crate) fn alloc(
        hmac: HmacReduced,
        // Human-readable label.
        label: &'static [u8],
        // Output length.
        out_len: usize,
    ) -> Result<Self, FError> {
        assert!(
            out_len <= 32,
            "output length larger than 32 is not supported"
        );

        let mut output: Vector<U8> = hmac.output().into();
        output.truncate(out_len);

        Ok(Self {
            label: HkdfLabelClear::new(label, out_len),
            hmac,
            ctx: None,
            output,
            state: State::WantsContext,
        })
    }

    /// Whether this functionality needs to be flushed.
    pub(crate) fn wants_flush(&self) -> bool {
        let pending_ctx = matches!(self.state, State::WantsContext) && self.is_ctx_set();
        pending_ctx || self.hmac.wants_flush()
    }

    /// Flushes the functionality.
    pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), FError> {
        self.hmac.flush(vm)?;

        match self.state {
            State::WantsContext => {
                let Some(ctx) = self.ctx.as_deref() else {
                    return Ok(());
                };
                // HKDF-Expand requires 0x01 to be concatenated.
                // see line: T(1) = HMAC-Hash(PRK, T(0) | info | 0x01) in
                // https://datatracker.ietf.org/doc/html/rfc5869
                self.label.set_ctx(ctx)?;
                let mut msg = self.label.output()?;
                msg.push(0x01);

                self.hmac.set_msg(&msg)?;
                self.hmac.flush(vm)?;
                self.state = State::ContextSet;
                // Recurse to pick up the ContextSet transition immediately.
                self.flush(vm)
            }
            State::ContextSet => {
                if self.hmac.is_complete() {
                    self.state = State::Complete;
                }
                Ok(())
            }
            State::Complete => Ok(()),
        }
    }

    /// Sets the HKDF-Expand-Label context.
    pub(crate) fn set_ctx(&mut self, ctx: &[u8]) -> Result<(), FError> {
        if self.is_ctx_set() {
            return Err(FError::state("context has already been set"));
        }
        self.ctx = Some(ctx.to_vec());
        Ok(())
    }

    /// Returns the HKDF-Expand-Label output.
    pub(crate) fn output(&self) -> Vector<U8> {
        self.output
    }

    /// Whether the context has been set.
    pub(crate) fn is_ctx_set(&self) -> bool {
        self.ctx.is_some()
    }

    /// Whether this functionality is complete.
    pub(crate) fn is_complete(&self) -> bool {
        matches!(self.state, State::Complete)
    }
}
/// State of [HkdfExpand].
#[derive(Debug)]
enum State {
    /// Wants the context to be set.
    WantsContext,
    /// Context has been set.
    ContextSet,
    /// The functionality is complete.
    Complete,
}

View File

@@ -0,0 +1,334 @@
//! `HKDF-Extract` function as defined in https://datatracker.ietf.org/doc/html/rfc5869
use crate::{
hmac::{normal::HmacNormal, Hmac},
FError, Mode,
};
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array, Vector,
},
Vm,
};
pub(crate) mod normal;
pub(crate) mod reduced;
/// Functionality for computing `HKDF-Extract` with private salt and public
/// IKM.
#[derive(Debug)]
pub(crate) enum HkdfExtract {
    /// Variant used with [crate::Mode::Reduced].
    Reduced(reduced::HkdfExtract),
    /// Variant used with [crate::Mode::Normal].
    Normal(normal::HkdfExtract),
}
impl HkdfExtract {
    /// Allocates a new HKDF-Extract with the given `ikm` and `hmac`
    /// instantiated with the salt.
    pub(crate) fn alloc(
        mode: Mode,
        vm: &mut dyn Vm<Binary>,
        ikm: [u8; 32],
        hmac: Hmac,
    ) -> Result<Self, FError> {
        // The HMAC variant is expected to agree with the requested mode.
        match (mode, hmac) {
            (Mode::Reduced, Hmac::Reduced(hmac)) => {
                Ok(Self::Reduced(reduced::HkdfExtract::alloc(ikm, hmac)?))
            }
            (Mode::Normal, Hmac::Normal(hmac)) => {
                Ok(Self::Normal(normal::HkdfExtract::alloc(vm, ikm, hmac)?))
            }
            _ => unreachable!("modes always match"),
        }
    }

    /// Whether this functionality needs to be flushed.
    pub(crate) fn wants_flush(&self) -> bool {
        match self {
            Self::Reduced(inner) => inner.wants_flush(),
            Self::Normal(inner) => inner.wants_flush(),
        }
    }

    /// Flushes the functionality.
    pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), FError> {
        match self {
            Self::Reduced(inner) => inner.flush(vm),
            Self::Normal(inner) => inner.flush(),
        }
    }

    /// Returns HKDF-Extract output.
    pub(crate) fn output(&self) -> Vector<U8> {
        match self {
            Self::Reduced(inner) => inner.output(),
            Self::Normal(inner) => inner.output(),
        }
    }

    /// Whether this functionality is complete.
    pub(crate) fn is_complete(&self) -> bool {
        match self {
            Self::Reduced(inner) => inner.is_complete(),
            Self::Normal(inner) => inner.is_complete(),
        }
    }
}
/// Functionality for computing `HKDF-Extract` with private IKM and public
/// salt.
#[derive(Debug)]
pub(crate) struct HkdfExtractPrivIkm {
    // HKDF-Extract output, i.e. the HMAC of the IKM under the salt.
    output: Vector<U8>,
    // Current state of the functionality.
    state: State,
}
impl HkdfExtractPrivIkm {
    /// Allocates a new HKDF-Extract with the given `ikm` and `hmac`
    /// instantiated with the salt.
    pub(crate) fn alloc(
        vm: &mut dyn Vm<Binary>,
        ikm: Array<U8, 32>,
        mut hmac: HmacNormal,
    ) -> Result<Self, FError> {
        // HKDF-Extract is HMAC(salt, IKM); the salt is already baked into
        // `hmac`, so the IKM becomes the message.
        hmac.set_msg(vm, &[ikm.into()])?;
        let output = hmac.output()?.into();

        Ok(Self {
            output,
            state: State::Setup,
        })
    }

    /// Whether this functionality needs to be flushed.
    pub(crate) fn wants_flush(&self) -> bool {
        matches!(self.state, State::Setup)
    }

    /// Flushes the functionality.
    pub(crate) fn flush(&mut self) {
        if matches!(self.state, State::Setup) {
            self.state = State::Complete;
        }
    }

    /// Returns HKDF-Extract output.
    pub(crate) fn output(&self) -> Vector<U8> {
        self.output
    }

    /// Whether this functionality is complete.
    pub(crate) fn is_complete(&self) -> bool {
        matches!(self.state, State::Complete)
    }
}
/// State of the extract functionalities.
// NOTE(review): both variants are unit-like here; confirm the
// `large_enum_variant` allow is still needed.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub(crate) enum State {
    /// The functionality has been set up.
    Setup,
    /// The functionality is complete.
    Complete,
}
#[cfg(test)]
mod tests {
use crate::{
hmac::{clear, normal::HmacNormal, Hmac},
kdf::extract::{HkdfExtract, HkdfExtractPrivIkm},
test_utils::mock_vm,
Mode,
};
use mpz_common::context::test_st_context;
use mpz_vm_core::{
memory::{binary::U8, Array, MemoryExt, ViewExt},
Execute,
};
use rstest::*;
#[tokio::test]
async fn test_hkdf_extract_priv_ikm() {
for fixture in test_fixtures() {
let (salt, ikm, secret) = fixture;
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut leader, mut follower) = mock_vm();
let ikm: [u8; 32] = ikm.try_into().unwrap();
let inner_state = clear::compute_inner_partial(&salt);
let outer_state = clear::compute_outer_partial(&salt);
// ------------------ LEADER
let vm = &mut leader;
let ikm_ref: Array<U8, 32> = vm.alloc().unwrap();
vm.mark_public(ikm_ref).unwrap();
vm.assign(ikm_ref, ikm).unwrap();
vm.commit(ikm_ref).unwrap();
let hmac = HmacNormal::alloc_with_state(vm, inner_state, outer_state).unwrap();
let mut hkdf_leader = HkdfExtractPrivIkm::alloc(vm, ikm_ref, hmac).unwrap();
let out_leader = hkdf_leader.output();
let mut leader_decode_fut = vm.decode(out_leader).unwrap();
// ------------------ FOLLOWER
let vm = &mut follower;
let ikm_ref: Array<U8, 32> = vm.alloc().unwrap();
vm.mark_public(ikm_ref).unwrap();
vm.assign(ikm_ref, ikm).unwrap();
vm.commit(ikm_ref).unwrap();
let hmac = HmacNormal::alloc_with_state(vm, inner_state, outer_state).unwrap();
let mut hkdf_follower = HkdfExtractPrivIkm::alloc(vm, ikm_ref, hmac).unwrap();
let out_follower = hkdf_follower.output();
let mut follower_decode_fut = vm.decode(out_follower).unwrap();
tokio::try_join!(
async {
leader.execute_all(&mut ctx_a).await.unwrap();
assert!(hkdf_leader.wants_flush());
hkdf_leader.flush();
assert!(!hkdf_leader.wants_flush());
Ok::<(), Box<dyn std::error::Error>>(())
},
async {
follower.execute_all(&mut ctx_b).await.unwrap();
assert!(hkdf_follower.wants_flush());
hkdf_follower.flush();
assert!(!hkdf_follower.wants_flush());
Ok::<(), Box<dyn std::error::Error>>(())
}
)
.unwrap();
let leader_out = leader_decode_fut.try_recv().unwrap().unwrap();
let follower_out = follower_decode_fut.try_recv().unwrap().unwrap();
assert_eq!(leader_out, follower_out);
assert_eq!(leader_out, secret);
}
}
    /// Tests HKDF-Extract in both modes against TLS 1.3 reference vectors:
    /// leader and follower must agree on the output and match the expected
    /// secret.
    #[rstest]
    #[case::normal(Mode::Normal)]
    #[case::reduced(Mode::Reduced)]
    #[tokio::test]
    async fn test_hkdf_extract(#[case] mode: Mode) {
        for fixture in test_fixtures() {
            let (salt, ikm, secret) = fixture;
            let (mut ctx_a, mut ctx_b) = test_st_context(8);
            let (mut leader, mut follower) = mock_vm();
            let salt: [u8; 32] = salt.try_into().unwrap();
            // ------------------ LEADER
            // The salt is public: assign it in plaintext and key the HMAC
            // with it.
            let vm = &mut leader;
            let salt_ref = vm.alloc_vec(32).unwrap();
            vm.mark_public(salt_ref).unwrap();
            vm.assign(salt_ref, salt.to_vec()).unwrap();
            vm.commit(salt_ref).unwrap();
            let hmac = Hmac::alloc(vm, salt_ref, mode).unwrap();
            let mut hkdf_leader =
                HkdfExtract::alloc(mode, vm, ikm.clone().try_into().unwrap(), hmac).unwrap();
            let out_leader = hkdf_leader.output();
            let mut leader_decode_fut = leader.decode(out_leader).unwrap();
            // ------------------ FOLLOWER
            // Mirrors the leader's setup with identical public inputs.
            let vm = &mut follower;
            let salt_ref = vm.alloc_vec(32).unwrap();
            vm.mark_public(salt_ref).unwrap();
            vm.assign(salt_ref, salt.to_vec()).unwrap();
            vm.commit(salt_ref).unwrap();
            let hmac = Hmac::alloc(vm, salt_ref, mode).unwrap();
            let mut hkdf_follower =
                HkdfExtract::alloc(mode, vm, ikm.try_into().unwrap(), hmac).unwrap();
            let out_follower = hkdf_follower.output();
            let mut follower_decode_fut = follower.decode(out_follower).unwrap();
            // Each party: execute pending operations, flush the HKDF, then
            // execute the newly-queued operations.
            tokio::try_join!(
                async {
                    leader.execute_all(&mut ctx_a).await.unwrap();
                    assert!(hkdf_leader.wants_flush());
                    hkdf_leader.flush(&mut leader).unwrap();
                    // A single flush is expected to be sufficient here.
                    assert!(!hkdf_leader.wants_flush());
                    leader.execute_all(&mut ctx_a).await.unwrap();
                    Ok::<(), Box<dyn std::error::Error>>(())
                },
                async {
                    follower.execute_all(&mut ctx_b).await.unwrap();
                    assert!(hkdf_follower.wants_flush());
                    hkdf_follower.flush(&mut follower).unwrap();
                    assert!(!hkdf_follower.wants_flush());
                    follower.execute_all(&mut ctx_b).await.unwrap();
                    Ok::<(), Box<dyn std::error::Error>>(())
                }
            )
            .unwrap();
            let out_leader = leader_decode_fut.try_recv().unwrap().unwrap();
            let out_follower = follower_decode_fut.try_recv().unwrap().unwrap();
            // Both parties must agree and match the reference secret.
            assert_eq!(out_leader, out_follower);
            assert_eq!(out_leader, secret);
        }
    }
// Reference values from https://datatracker.ietf.org/doc/html/draft-ietf-tls-tls13-vectors-06
fn test_fixtures() -> Vec<(Vec<u8>, Vec<u8>, Vec<u8>)> {
vec![(
// SALT
from_hex_str::<32>("6f 26 15 a1 08 c7 02 c5 67 8f 54 fc 9d ba b6 97 16 c0 76 18 9c 48 25 0c eb ea c3 57 6c 36 11 ba").to_vec(),
// IKM
from_hex_str::<32>("81 51 d1 46 4c 1b 55 53 36 23 b9 c2 24 6a 6a 0e 6e 7e 18 50 63 e1 4a fd af f0 b6 e1 c6 1a 86 42").to_vec(),
// SECRET
from_hex_str::<32>("5b 4f 96 5d f0 3c 68 2c 46 e6 ee 86 c3 11 63 66 15 a1 d2 bb b2 43 45 c2 52 05 95 3c 87 9e 8d 06").to_vec(),
),
(
// SALT
from_hex_str::<32>("c8 61 57 19 e2 40 37 47 b6 10 76 2c 72 b8 f4 da 5c 60 99 57 65 d4 04 a9 d0 06 b9 b0 72 7b a5 83").to_vec(),
// IKM
from_hex_str::<32>("00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00").to_vec(),
// SECRET
from_hex_str::<32>("5c 79 d1 69 42 4e 26 2b 56 32 03 62 7b e4 eb 51 03 3f 58 8c 43 c9 ce 03 73 37 2d bc bc 01 85 a7").to_vec(),
),
(
// SALT
from_hex_str::<32>("9e fc 79 87 0b 08 c4 c6 51 20 52 50 af 9b 83 04 79 11 b7 83 d5 d7 67 8d 7c cc e7 18 18 9e a2 ec").to_vec(),
// IKM
from_hex_str::<32>("b0 66 a1 5b c1 aa ee f8 79 0e 0b 02 e6 2f 82 dc 44 64 46 e3 7d 6d 61 22 b0 d3 b9 94 ef 11 dd 3c").to_vec(),
// SECRET
from_hex_str::<32>("ea d8 b8 c5 9a 15 df 29 d7 9f a4 ac 31 d5 f7 c9 0e 2e 5c 87 d9 ea fe d1 fe 69 16 cf 2f 29 37 34").to_vec(),
)
]
}
fn from_hex_str<const N: usize>(s: &str) -> [u8; N] {
let bytes: Vec<u8> = hex::decode(s.split_whitespace().collect::<String>()).unwrap();
bytes
.try_into()
.expect("Hex string length does not match array size")
}
}

View File

@@ -0,0 +1,76 @@
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array, MemoryExt, Vector, ViewExt,
},
Vm,
};
use crate::{hmac::normal::HmacNormal, FError};
/// Functionality for HKDF-Extract computation with private salt and public
/// IKM.
#[derive(Debug)]
pub(crate) struct HkdfExtract {
    // HMAC functionality instantiated with the salt; the public IKM is fed
    // to it as the message.
    hmac: HmacNormal,
    // HKDF-Extract output, i.e. the HMAC digest.
    output: Vector<U8>,
    // Current state of the functionality.
    state: State,
}
impl HkdfExtract {
    /// Allocates a new HKDF-Extract with the given `ikm` and `hmac`
    /// instantiated with the salt.
    ///
    /// # Arguments
    ///
    /// * `vm` - Virtual machine.
    /// * `ikm` - Public input keying material, fed to the HMAC as the
    ///   message.
    /// * `hmac` - HMAC functionality keyed with the salt.
    pub(crate) fn alloc(
        vm: &mut dyn Vm<Binary>,
        ikm: [u8; 32],
        mut hmac: HmacNormal,
    ) -> Result<Self, FError> {
        // The IKM is public: both parties assign the identical plaintext
        // value.
        let msg: Array<U8, 32> = vm.alloc().map_err(FError::vm)?;
        vm.mark_public(msg).map_err(FError::vm)?;
        vm.assign(msg, ikm).map_err(FError::vm)?;
        vm.commit(msg).map_err(FError::vm)?;

        hmac.set_msg(vm, &[msg.into()])?;

        Ok(Self {
            output: hmac.output()?.into(),
            hmac,
            // `Setup` is a unit variant; construct it without braces.
            state: State::Setup,
        })
    }

    /// Whether this functionality needs to be flushed.
    pub(crate) fn wants_flush(&self) -> bool {
        matches!(self.state, State::Setup) || self.hmac.wants_flush()
    }

    /// Flushes the functionality, driving the inner HMAC and transitioning
    /// to `Complete` once the HMAC reports completion.
    pub(crate) fn flush(&mut self) -> Result<(), FError> {
        self.hmac.flush()?;

        if let State::Setup = &mut self.state {
            if self.hmac.is_complete() {
                self.state = State::Complete;
            }
        }

        Ok(())
    }

    /// Returns HKDF-Extract output.
    pub(crate) fn output(&self) -> Vector<U8> {
        self.output
    }

    /// Whether this functionality is complete.
    pub(crate) fn is_complete(&self) -> bool {
        matches!(self.state, State::Complete)
    }
}
/// State of the HKDF-Extract functionality.
///
/// Both variants are unit-like, so no `large_enum_variant` suppression is
/// needed.
#[derive(Debug)]
pub(crate) enum State {
    /// Allocated; waiting for the inner HMAC to complete.
    Setup,
    /// The output has been computed.
    Complete,
}

View File

@@ -0,0 +1,67 @@
use crate::{hmac::reduced::HmacReduced, FError};
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Vector,
},
Vm,
};
/// Functionality for HKDF-Extract computation with private salt and public
/// IKM.
#[derive(Debug)]
pub(crate) struct HkdfExtract {
    // Reduced-mode HMAC functionality instantiated with the salt; the public
    // IKM is fed to it as the message.
    hmac: HmacReduced,
    // HKDF-Extract output, i.e. the HMAC digest.
    output: Vector<U8>,
    // Current state of the functionality.
    state: State,
}
impl HkdfExtract {
    /// Allocates a new HKDF-Extract with the given `ikm` and `hmac`
    /// instantiated with the salt.
    pub(crate) fn alloc(ikm: [u8; 32], mut hmac: HmacReduced) -> Result<Self, FError> {
        // Feed the public IKM to the HMAC as the message.
        hmac.set_msg(&ikm)?;

        let output = hmac.output().into();
        Ok(Self {
            hmac,
            output,
            state: State::Setup,
        })
    }

    /// Whether this functionality needs to be flushed.
    pub(crate) fn wants_flush(&self) -> bool {
        self.hmac.wants_flush() || matches!(self.state, State::Setup)
    }

    /// Flushes the functionality, driving the inner HMAC and transitioning
    /// to `Complete` once the HMAC reports completion.
    pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), FError> {
        self.hmac.flush(vm)?;

        if matches!(self.state, State::Setup) && self.hmac.is_complete() {
            self.state = State::Complete;
        }

        Ok(())
    }

    /// Returns HKDF-Extract output.
    pub(crate) fn output(&self) -> Vector<U8> {
        self.output
    }

    /// Whether this functionality is complete.
    pub(crate) fn is_complete(&self) -> bool {
        matches!(self.state, State::Complete)
    }
}
/// State of the HKDF-Extract functionality.
///
/// Both variants are unit-like, so no `large_enum_variant` suppression is
/// needed.
#[derive(Debug)]
pub(crate) enum State {
    /// Allocated; waiting for the inner HMAC to complete.
    Setup,
    /// The output has been computed.
    Complete,
}

View File

@@ -0,0 +1,2 @@
/// HKDF-Expand functionality.
pub(crate) mod expand;
/// HKDF-Extract functionality.
pub(crate) mod extract;

View File

@@ -1,4 +1,5 @@
//! This crate contains the protocol for computing TLS 1.2 SHA-256 HMAC PRF.
//! MPC protocols for computing HMAC-SHA-256-based PRF for TLS 1.2 and key
//! schedule for TLS 1.3.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
@@ -12,36 +13,15 @@ mod config;
pub use config::Mode;
mod error;
pub use error::PrfError;
pub use error::FError;
mod kdf;
mod prf;
pub use prf::MpcPrf;
mod tls12;
mod tls13;
use mpz_vm_core::memory::{binary::U8, Array};
/// PRF output.
#[derive(Debug, Clone, Copy)]
pub struct PrfOutput {
/// TLS session keys.
pub keys: SessionKeys,
/// Client finished verify data.
pub cf_vd: Array<U8, 12>,
/// Server finished verify data.
pub sf_vd: Array<U8, 12>,
}
/// Session keys computed by the PRF.
#[derive(Debug, Clone, Copy)]
pub struct SessionKeys {
/// Client write key.
pub client_write_key: Array<U8, 16>,
/// Server write key.
pub server_write_key: Array<U8, 16>,
/// Client IV.
pub client_iv: Array<U8, 4>,
/// Server IV.
pub server_iv: Array<U8, 4>,
}
pub use tls12::{PrfOutput, SessionKeys, Tls12Prf};
pub use tls13::{ApplicationKeys, HandshakeKeys, Role, Tls13KeySched};
fn sha256(mut state: [u32; 8], pos: usize, msg: &[u8]) -> [u32; 8] {
use sha2::{
@@ -60,6 +40,20 @@ fn sha256(mut state: [u32; 8], pos: usize, msg: &[u8]) -> [u32; 8] {
state
}
pub(crate) fn compress_256(mut state: [u32; 8], msg: &[u8]) -> [u32; 8] {
use sha2::{
compress256,
digest::{
block_buffer::{BlockBuffer, Eager},
generic_array::typenum::U64,
},
};
let mut buffer = BlockBuffer::<U64, Eager>::default();
buffer.digest_blocks(msg, |b| compress256(&mut state, b));
state
}
fn state_to_bytes(input: [u32; 8]) -> [u8; 32] {
let mut output = [0_u8; 32];
for (k, byte_chunk) in input.iter().enumerate() {
@@ -68,204 +62,3 @@ fn state_to_bytes(input: [u32; 8]) -> [u8; 32] {
}
output
}
#[cfg(test)]
mod tests {
use crate::{
test_utils::{prf_cf_vd, prf_keys, prf_ms, prf_sf_vd},
Mode, MpcPrf, SessionKeys,
};
use mpz_common::context::test_st_context;
use mpz_ideal_vm::IdealVm;
use mpz_vm_core::{
memory::{binary::U8, Array, MemoryExt, ViewExt},
Execute,
};
use rand::{rngs::StdRng, Rng, SeedableRng};
#[tokio::test]
async fn test_prf_reduced() {
let mode = Mode::Reduced;
test_prf(mode).await;
}
#[tokio::test]
async fn test_prf_normal() {
let mode = Mode::Normal;
test_prf(mode).await;
}
async fn test_prf(mode: Mode) {
let mut rng = StdRng::seed_from_u64(1);
// Test input
let pms: [u8; 32] = rng.random();
let client_random: [u8; 32] = rng.random();
let server_random: [u8; 32] = rng.random();
let cf_hs_hash: [u8; 32] = rng.random();
let sf_hs_hash: [u8; 32] = rng.random();
// Expected output
let ms_expected = prf_ms(pms, client_random, server_random);
let [cwk_expected, swk_expected, civ_expected, siv_expected] =
prf_keys(ms_expected, client_random, server_random);
let cwk_expected: [u8; 16] = cwk_expected.try_into().unwrap();
let swk_expected: [u8; 16] = swk_expected.try_into().unwrap();
let civ_expected: [u8; 4] = civ_expected.try_into().unwrap();
let siv_expected: [u8; 4] = siv_expected.try_into().unwrap();
let cf_vd_expected = prf_cf_vd(ms_expected, cf_hs_hash);
let sf_vd_expected = prf_sf_vd(ms_expected, sf_hs_hash);
let cf_vd_expected: [u8; 12] = cf_vd_expected.try_into().unwrap();
let sf_vd_expected: [u8; 12] = sf_vd_expected.try_into().unwrap();
// Set up vm and prf
let (mut ctx_a, mut ctx_b) = test_st_context(128);
let mut leader = IdealVm::new();
let mut follower = IdealVm::new();
let leader_pms: Array<U8, 32> = leader.alloc().unwrap();
leader.mark_public(leader_pms).unwrap();
leader.assign(leader_pms, pms).unwrap();
leader.commit(leader_pms).unwrap();
let follower_pms: Array<U8, 32> = follower.alloc().unwrap();
follower.mark_public(follower_pms).unwrap();
follower.assign(follower_pms, pms).unwrap();
follower.commit(follower_pms).unwrap();
let mut prf_leader = MpcPrf::new(mode);
let mut prf_follower = MpcPrf::new(mode);
let leader_prf_out = prf_leader.alloc(&mut leader, leader_pms).unwrap();
let follower_prf_out = prf_follower.alloc(&mut follower, follower_pms).unwrap();
// client_random and server_random
prf_leader.set_client_random(client_random).unwrap();
prf_follower.set_client_random(client_random).unwrap();
prf_leader.set_server_random(server_random).unwrap();
prf_follower.set_server_random(server_random).unwrap();
let SessionKeys {
client_write_key: cwk_leader,
server_write_key: swk_leader,
client_iv: civ_leader,
server_iv: siv_leader,
} = leader_prf_out.keys;
let mut cwk_leader = leader.decode(cwk_leader).unwrap();
let mut swk_leader = leader.decode(swk_leader).unwrap();
let mut civ_leader = leader.decode(civ_leader).unwrap();
let mut siv_leader = leader.decode(siv_leader).unwrap();
let SessionKeys {
client_write_key: cwk_follower,
server_write_key: swk_follower,
client_iv: civ_follower,
server_iv: siv_follower,
} = follower_prf_out.keys;
let mut cwk_follower = follower.decode(cwk_follower).unwrap();
let mut swk_follower = follower.decode(swk_follower).unwrap();
let mut civ_follower = follower.decode(civ_follower).unwrap();
let mut siv_follower = follower.decode(siv_follower).unwrap();
while prf_leader.wants_flush() || prf_follower.wants_flush() {
tokio::try_join!(
async {
prf_leader.flush(&mut leader).unwrap();
leader.execute_all(&mut ctx_a).await
},
async {
prf_follower.flush(&mut follower).unwrap();
follower.execute_all(&mut ctx_b).await
}
)
.unwrap();
}
let cwk_leader = cwk_leader.try_recv().unwrap().unwrap();
let swk_leader = swk_leader.try_recv().unwrap().unwrap();
let civ_leader = civ_leader.try_recv().unwrap().unwrap();
let siv_leader = siv_leader.try_recv().unwrap().unwrap();
let cwk_follower = cwk_follower.try_recv().unwrap().unwrap();
let swk_follower = swk_follower.try_recv().unwrap().unwrap();
let civ_follower = civ_follower.try_recv().unwrap().unwrap();
let siv_follower = siv_follower.try_recv().unwrap().unwrap();
assert_eq!(cwk_leader, cwk_follower);
assert_eq!(swk_leader, swk_follower);
assert_eq!(civ_leader, civ_follower);
assert_eq!(siv_leader, siv_follower);
assert_eq!(cwk_leader, cwk_expected);
assert_eq!(swk_leader, swk_expected);
assert_eq!(civ_leader, civ_expected);
assert_eq!(siv_leader, siv_expected);
// client finished
prf_leader.set_cf_hash(cf_hs_hash).unwrap();
prf_follower.set_cf_hash(cf_hs_hash).unwrap();
let cf_vd_leader = leader_prf_out.cf_vd;
let cf_vd_follower = follower_prf_out.cf_vd;
let mut cf_vd_leader = leader.decode(cf_vd_leader).unwrap();
let mut cf_vd_follower = follower.decode(cf_vd_follower).unwrap();
while prf_leader.wants_flush() || prf_follower.wants_flush() {
tokio::try_join!(
async {
prf_leader.flush(&mut leader).unwrap();
leader.execute_all(&mut ctx_a).await
},
async {
prf_follower.flush(&mut follower).unwrap();
follower.execute_all(&mut ctx_b).await
}
)
.unwrap();
}
let cf_vd_leader = cf_vd_leader.try_recv().unwrap().unwrap();
let cf_vd_follower = cf_vd_follower.try_recv().unwrap().unwrap();
assert_eq!(cf_vd_leader, cf_vd_follower);
assert_eq!(cf_vd_leader, cf_vd_expected);
// server finished
prf_leader.set_sf_hash(sf_hs_hash).unwrap();
prf_follower.set_sf_hash(sf_hs_hash).unwrap();
let sf_vd_leader = leader_prf_out.sf_vd;
let sf_vd_follower = follower_prf_out.sf_vd;
let mut sf_vd_leader = leader.decode(sf_vd_leader).unwrap();
let mut sf_vd_follower = follower.decode(sf_vd_follower).unwrap();
while prf_leader.wants_flush() || prf_follower.wants_flush() {
tokio::try_join!(
async {
prf_leader.flush(&mut leader).unwrap();
leader.execute_all(&mut ctx_a).await
},
async {
prf_follower.flush(&mut follower).unwrap();
follower.execute_all(&mut ctx_b).await
}
)
.unwrap();
}
let sf_vd_leader = sf_vd_leader.try_recv().unwrap().unwrap();
let sf_vd_follower = sf_vd_follower.try_recv().unwrap().unwrap();
assert_eq!(sf_vd_leader, sf_vd_follower);
assert_eq!(sf_vd_leader, sf_vd_expected);
}
}

View File

@@ -1,409 +0,0 @@
use crate::{
hmac::{IPAD, OPAD},
Mode, PrfError, PrfOutput,
};
use mpz_circuits::{circuits::xor, Circuit, CircuitBuilder};
use mpz_hash::sha256::Sha256;
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array, MemoryExt, StaticSize, Vector, ViewExt,
},
Call, CallableExt, Vm,
};
use std::{fmt::Debug, sync::Arc};
use tracing::instrument;
mod state;
use state::State;
mod function;
use function::Prf;
/// MPC PRF for computing TLS 1.2 HMAC-SHA256 PRF.
#[derive(Debug)]
pub struct MpcPrf {
mode: Mode,
state: State,
}
impl MpcPrf {
/// Creates a new instance of the PRF.
///
/// # Arguments
///
/// `mode` - The PRF mode.
pub fn new(mode: Mode) -> MpcPrf {
Self {
mode,
state: State::Initialized,
}
}
/// Allocates resources for the PRF.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `pms` - The pre-master secret.
#[instrument(level = "debug", skip_all, err)]
pub fn alloc(
&mut self,
vm: &mut dyn Vm<Binary>,
pms: Array<U8, 32>,
) -> Result<PrfOutput, PrfError> {
let State::Initialized = self.state.take() else {
return Err(PrfError::state("PRF not in initialized state"));
};
let mode = self.mode;
let pms: Vector<U8> = pms.into();
let outer_partial_pms = compute_partial(vm, pms, OPAD)?;
let inner_partial_pms = compute_partial(vm, pms, IPAD)?;
let master_secret =
Prf::alloc_master_secret(mode, vm, outer_partial_pms, inner_partial_pms)?;
let ms = master_secret.output();
let ms = merge_outputs(vm, ms, 48)?;
let outer_partial_ms = compute_partial(vm, ms, OPAD)?;
let inner_partial_ms = compute_partial(vm, ms, IPAD)?;
let key_expansion =
Prf::alloc_key_expansion(mode, vm, outer_partial_ms.clone(), inner_partial_ms.clone())?;
let client_finished = Prf::alloc_client_finished(
mode,
vm,
outer_partial_ms.clone(),
inner_partial_ms.clone(),
)?;
let server_finished = Prf::alloc_server_finished(
mode,
vm,
outer_partial_ms.clone(),
inner_partial_ms.clone(),
)?;
self.state = State::SessionKeys {
client_random: None,
master_secret,
key_expansion,
client_finished,
server_finished,
};
self.state.prf_output(vm)
}
/// Sets the client random.
///
/// # Arguments
///
/// * `random` - The client random.
#[instrument(level = "debug", skip_all, err)]
pub fn set_client_random(&mut self, random: [u8; 32]) -> Result<(), PrfError> {
let State::SessionKeys { client_random, .. } = &mut self.state else {
return Err(PrfError::state("PRF not set up"));
};
*client_random = Some(random);
Ok(())
}
/// Sets the server random.
///
/// # Arguments
///
/// * `random` - The server random.
#[instrument(level = "debug", skip_all, err)]
pub fn set_server_random(&mut self, random: [u8; 32]) -> Result<(), PrfError> {
let State::SessionKeys {
client_random,
master_secret,
key_expansion,
..
} = &mut self.state
else {
return Err(PrfError::state("PRF not set up"));
};
let client_random = client_random.expect("Client random should have been set by now");
let server_random = random;
let mut seed_ms = client_random.to_vec();
seed_ms.extend_from_slice(&server_random);
master_secret.set_start_seed(seed_ms);
let mut seed_ke = server_random.to_vec();
seed_ke.extend_from_slice(&client_random);
key_expansion.set_start_seed(seed_ke);
Ok(())
}
/// Sets the client finished handshake hash.
///
/// # Arguments
///
/// * `handshake_hash` - The handshake transcript hash.
#[instrument(level = "debug", skip_all, err)]
pub fn set_cf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), PrfError> {
let State::ClientFinished {
client_finished, ..
} = &mut self.state
else {
return Err(PrfError::state("PRF not in client finished state"));
};
let seed_cf = handshake_hash.to_vec();
client_finished.set_start_seed(seed_cf);
Ok(())
}
/// Sets the server finished handshake hash.
///
/// # Arguments
///
/// * `handshake_hash` - The handshake transcript hash.
#[instrument(level = "debug", skip_all, err)]
pub fn set_sf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), PrfError> {
let State::ServerFinished { server_finished } = &mut self.state else {
return Err(PrfError::state("PRF not in server finished state"));
};
let seed_sf = handshake_hash.to_vec();
server_finished.set_start_seed(seed_sf);
Ok(())
}
/// Returns if the PRF needs to be flushed.
pub fn wants_flush(&self) -> bool {
match &self.state {
State::Initialized => false,
State::SessionKeys {
master_secret,
key_expansion,
..
} => master_secret.wants_flush() || key_expansion.wants_flush(),
State::ClientFinished {
client_finished, ..
} => client_finished.wants_flush(),
State::ServerFinished { server_finished } => server_finished.wants_flush(),
State::Complete => false,
State::Error => false,
}
}
/// Flushes the PRF.
pub fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), PrfError> {
self.state = match self.state.take() {
State::SessionKeys {
client_random,
mut master_secret,
mut key_expansion,
client_finished,
server_finished,
} => {
master_secret.flush(vm)?;
key_expansion.flush(vm)?;
if !master_secret.wants_flush() && !key_expansion.wants_flush() {
State::ClientFinished {
client_finished,
server_finished,
}
} else {
State::SessionKeys {
client_random,
master_secret,
key_expansion,
client_finished,
server_finished,
}
}
}
State::ClientFinished {
mut client_finished,
server_finished,
} => {
client_finished.flush(vm)?;
if !client_finished.wants_flush() {
State::ServerFinished { server_finished }
} else {
State::ClientFinished {
client_finished,
server_finished,
}
}
}
State::ServerFinished {
mut server_finished,
} => {
server_finished.flush(vm)?;
if !server_finished.wants_flush() {
State::Complete
} else {
State::ServerFinished { server_finished }
}
}
other => other,
};
Ok(())
}
}
/// Depending on the provided `mask` computes and returns `outer_partial` or
/// `inner_partial` for HMAC-SHA256.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `key` - Key to pad and xor.
/// * `mask`- Mask used for padding.
fn compute_partial(
vm: &mut dyn Vm<Binary>,
key: Vector<U8>,
mask: [u8; 64],
) -> Result<Sha256, PrfError> {
let xor = Arc::new(xor(8 * 64));
let additional_len = 64 - key.len();
let padding = vec![0_u8; additional_len];
let padding_ref: Vector<U8> = vm.alloc_vec(additional_len).map_err(PrfError::vm)?;
vm.mark_public(padding_ref).map_err(PrfError::vm)?;
vm.assign(padding_ref, padding).map_err(PrfError::vm)?;
vm.commit(padding_ref).map_err(PrfError::vm)?;
let mask_ref: Array<U8, 64> = vm.alloc().map_err(PrfError::vm)?;
vm.mark_public(mask_ref).map_err(PrfError::vm)?;
vm.assign(mask_ref, mask).map_err(PrfError::vm)?;
vm.commit(mask_ref).map_err(PrfError::vm)?;
let xor = Call::builder(xor)
.arg(key)
.arg(padding_ref)
.arg(mask_ref)
.build()
.map_err(PrfError::vm)?;
let key_padded: Vector<U8> = vm.call(xor).map_err(PrfError::vm)?;
let mut sha = Sha256::new_with_init(vm)?;
sha.update(&key_padded);
sha.compress(vm)?;
Ok(sha)
}
fn merge_outputs(
vm: &mut dyn Vm<Binary>,
inputs: Vec<Array<U8, 32>>,
output_bytes: usize,
) -> Result<Vector<U8>, PrfError> {
assert!(output_bytes <= 32 * inputs.len());
let bits = Array::<U8, 32>::SIZE * inputs.len();
let circ = gen_merge_circ(bits);
let mut builder = Call::builder(circ);
for &input in inputs.iter() {
builder = builder.arg(input);
}
let call = builder.build().map_err(PrfError::vm)?;
let mut output: Vector<U8> = vm.call(call).map_err(PrfError::vm)?;
output.truncate(output_bytes);
Ok(output)
}
fn gen_merge_circ(size: usize) -> Arc<Circuit> {
let mut builder = CircuitBuilder::new();
let inputs = (0..size).map(|_| builder.add_input()).collect::<Vec<_>>();
for input in inputs.chunks_exact(8) {
for byte in input.chunks_exact(8) {
for &feed in byte.iter() {
let output = builder.add_id_gate(feed);
builder.add_output(output);
}
}
}
Arc::new(builder.build().expect("merge circuit is valid"))
}
#[cfg(test)]
mod tests {
use crate::prf::merge_outputs;
use mpz_common::context::test_st_context;
use mpz_ideal_vm::IdealVm;
use mpz_vm_core::{
memory::{binary::U8, Array, MemoryExt, ViewExt},
Execute,
};
#[tokio::test]
async fn test_merge_outputs() {
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let mut leader = IdealVm::new();
let mut follower = IdealVm::new();
let input1: [u8; 32] = std::array::from_fn(|i| i as u8);
let input2: [u8; 32] = std::array::from_fn(|i| i as u8 + 32);
let mut expected = input1.to_vec();
expected.extend_from_slice(&input2);
expected.truncate(48);
// leader
let input1_leader: Array<U8, 32> = leader.alloc().unwrap();
let input2_leader: Array<U8, 32> = leader.alloc().unwrap();
leader.mark_public(input1_leader).unwrap();
leader.mark_public(input2_leader).unwrap();
leader.assign(input1_leader, input1).unwrap();
leader.assign(input2_leader, input2).unwrap();
leader.commit(input1_leader).unwrap();
leader.commit(input2_leader).unwrap();
let merged_leader =
merge_outputs(&mut leader, vec![input1_leader, input2_leader], 48).unwrap();
let mut merged_leader = leader.decode(merged_leader).unwrap();
// follower
let input1_follower: Array<U8, 32> = follower.alloc().unwrap();
let input2_follower: Array<U8, 32> = follower.alloc().unwrap();
follower.mark_public(input1_follower).unwrap();
follower.mark_public(input2_follower).unwrap();
follower.assign(input1_follower, input1).unwrap();
follower.assign(input2_follower, input2).unwrap();
follower.commit(input1_follower).unwrap();
follower.commit(input2_follower).unwrap();
let merged_follower =
merge_outputs(&mut follower, vec![input1_follower, input2_follower], 48).unwrap();
let mut merged_follower = follower.decode(merged_follower).unwrap();
tokio::try_join!(
leader.execute_all(&mut ctx_a),
follower.execute_all(&mut ctx_b)
)
.unwrap();
let merged_leader = merged_leader.try_recv().unwrap().unwrap();
let merged_follower = merged_follower.try_recv().unwrap().unwrap();
assert_eq!(merged_leader, merged_follower);
assert_eq!(merged_leader, expected);
}
}

View File

@@ -1,11 +1,10 @@
//! Provides [`Prf`], for computing the TLS 1.2 PRF.
use crate::{Mode, PrfError};
use mpz_hash::sha256::Sha256;
use crate::{hmac::Hmac, FError, Mode};
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array,
Vector,
},
Vm,
};
@@ -20,90 +19,107 @@ pub(crate) enum Prf {
}
impl Prf {
/// Allocates master secret.
pub(crate) fn alloc_master_secret(
mode: Mode,
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
hmac: Hmac,
) -> Result<Self, FError> {
let prf = match mode {
Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_master_secret(
vm,
outer_partial,
inner_partial,
)?),
Mode::Normal => Self::Normal(normal::PrfFunction::alloc_master_secret(
vm,
outer_partial,
inner_partial,
)?),
Mode::Reduced => {
if let Hmac::Reduced(hmac) = hmac {
Self::Reduced(reduced::PrfFunction::alloc_master_secret(vm, hmac)?)
} else {
unreachable!("modes always match");
}
}
Mode::Normal => {
if let Hmac::Normal(hmac) = hmac {
Self::Normal(normal::PrfFunction::alloc_master_secret(vm, hmac)?)
} else {
unreachable!("modes always match");
}
}
};
Ok(prf)
}
/// Allocates key expansion.
pub(crate) fn alloc_key_expansion(
mode: Mode,
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
hmac: Hmac,
) -> Result<Self, FError> {
let prf = match mode {
Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_key_expansion(
vm,
outer_partial,
inner_partial,
)?),
Mode::Normal => Self::Normal(normal::PrfFunction::alloc_key_expansion(
vm,
outer_partial,
inner_partial,
)?),
Mode::Reduced => {
if let Hmac::Reduced(hmac) = hmac {
Self::Reduced(reduced::PrfFunction::alloc_key_expansion(vm, hmac)?)
} else {
unreachable!("modes always match");
}
}
Mode::Normal => {
if let Hmac::Normal(hmac) = hmac {
Self::Normal(normal::PrfFunction::alloc_key_expansion(vm, hmac)?)
} else {
unreachable!("modes always match");
}
}
};
Ok(prf)
}
/// Allocates client finished.
pub(crate) fn alloc_client_finished(
config: Mode,
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
hmac: Hmac,
) -> Result<Self, FError> {
let prf = match config {
Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_client_finished(
vm,
outer_partial,
inner_partial,
)?),
Mode::Normal => Self::Normal(normal::PrfFunction::alloc_client_finished(
vm,
outer_partial,
inner_partial,
)?),
Mode::Reduced => {
if let Hmac::Reduced(hmac) = hmac {
Self::Reduced(reduced::PrfFunction::alloc_client_finished(vm, hmac)?)
} else {
unreachable!("modes always match");
}
}
Mode::Normal => {
if let Hmac::Normal(hmac) = hmac {
Self::Normal(normal::PrfFunction::alloc_client_finished(vm, hmac)?)
} else {
unreachable!("modes always match");
}
}
};
Ok(prf)
}
/// Allocates server finished.
pub(crate) fn alloc_server_finished(
config: Mode,
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
hmac: Hmac,
) -> Result<Self, FError> {
let prf = match config {
Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_server_finished(
vm,
outer_partial,
inner_partial,
)?),
Mode::Normal => Self::Normal(normal::PrfFunction::alloc_server_finished(
vm,
outer_partial,
inner_partial,
)?),
Mode::Reduced => {
if let Hmac::Reduced(hmac) = hmac {
Self::Reduced(reduced::PrfFunction::alloc_server_finished(vm, hmac)?)
} else {
unreachable!("modes always match");
}
}
Mode::Normal => {
if let Hmac::Normal(hmac) = hmac {
Self::Normal(normal::PrfFunction::alloc_server_finished(vm, hmac)?)
} else {
unreachable!("modes always match");
}
}
};
Ok(prf)
}
/// Whether this functionality needs to be flushed.
pub(crate) fn wants_flush(&self) -> bool {
match self {
Prf::Reduced(prf) => prf.wants_flush(),
@@ -111,13 +127,15 @@ impl Prf {
}
}
pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), PrfError> {
/// Flushes the functionality.
pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), FError> {
match self {
Prf::Reduced(prf) => prf.flush(vm),
Prf::Normal(prf) => prf.flush(vm),
}
}
/// Sets the seed.
pub(crate) fn set_start_seed(&mut self, seed: Vec<u8>) {
match self {
Prf::Reduced(prf) => prf.set_start_seed(seed),
@@ -125,53 +143,52 @@ impl Prf {
}
}
pub(crate) fn output(&self) -> Vec<Array<U8, 32>> {
/// Returns the PRF output.
pub(crate) fn output(&self) -> Vector<U8> {
match self {
Prf::Reduced(prf) => prf.output(),
Prf::Normal(prf) => prf.output(),
}
}
/// Whether this functionality is complete.
pub(crate) fn is_complete(&self) -> bool {
match self {
Prf::Reduced(prf) => prf.is_complete(),
Prf::Normal(prf) => prf.is_complete(),
}
}
}
#[cfg(test)]
mod tests {
use crate::{
prf::{compute_partial, function::Prf},
test_utils::phash,
hmac::Hmac,
prf::function::Prf,
test_utils::{mock_vm, phash},
Mode,
};
use mpz_common::context::test_st_context;
use mpz_ideal_vm::IdealVm;
use mpz_vm_core::{
memory::{binary::U8, Array, MemoryExt, ViewExt},
Execute,
};
use rand::{rngs::ThreadRng, Rng};
use rstest::*;
const IPAD: [u8; 64] = [0x36; 64];
const OPAD: [u8; 64] = [0x5c; 64];
#[rstest]
#[case::normal(Mode::Normal)]
#[case::reduced(Mode::Reduced)]
#[tokio::test]
async fn test_phash_reduced() {
let mode = Mode::Reduced;
test_phash(mode).await;
}
#[tokio::test]
async fn test_phash_normal() {
let mode = Mode::Normal;
test_phash(mode).await;
}
async fn test_phash(mode: Mode) {
async fn test_phash(#[case] mode: Mode) {
let mut rng = ThreadRng::default();
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let mut leader = IdealVm::new();
let mut follower = IdealVm::new();
let (mut leader, mut follower) = mock_vm();
let key: [u8; 32] = rng.random();
let start_seed: Vec<u8> = vec![42; 64];
let output_len = 48;
let mut label_seed = b"master secret".to_vec();
label_seed.extend_from_slice(&start_seed);
@@ -182,48 +199,25 @@ mod tests {
leader.assign(leader_key, key).unwrap();
leader.commit(leader_key).unwrap();
let outer_partial_leader = compute_partial(&mut leader, leader_key.into(), OPAD).unwrap();
let inner_partial_leader = compute_partial(&mut leader, leader_key.into(), IPAD).unwrap();
let leader_hmac = Hmac::alloc(&mut leader, leader_key.into(), mode).unwrap();
let mut prf_leader = Prf::alloc_master_secret(
mode,
&mut leader,
outer_partial_leader,
inner_partial_leader,
)
.unwrap();
let mut prf_leader = Prf::alloc_master_secret(mode, &mut leader, leader_hmac).unwrap();
prf_leader.set_start_seed(start_seed.clone());
let mut prf_out_leader = vec![];
for p in prf_leader.output() {
let p_out = leader.decode(p).unwrap();
prf_out_leader.push(p_out)
}
let mut prf_out_leader = leader.decode(prf_leader.output()).unwrap();
let follower_key: Array<U8, 32> = follower.alloc().unwrap();
follower.mark_public(follower_key).unwrap();
follower.assign(follower_key, key).unwrap();
follower.commit(follower_key).unwrap();
let outer_partial_follower =
compute_partial(&mut follower, follower_key.into(), OPAD).unwrap();
let inner_partial_follower =
compute_partial(&mut follower, follower_key.into(), IPAD).unwrap();
let follower_hmac = Hmac::alloc(&mut follower, follower_key.into(), mode).unwrap();
let mut prf_follower = Prf::alloc_master_secret(
mode,
&mut follower,
outer_partial_follower,
inner_partial_follower,
)
.unwrap();
let mut prf_follower =
Prf::alloc_master_secret(mode, &mut follower, follower_hmac).unwrap();
prf_follower.set_start_seed(start_seed.clone());
let mut prf_out_follower = vec![];
for p in prf_follower.output() {
let p_out = follower.decode(p).unwrap();
prf_out_follower.push(p_out)
}
let mut prf_out_follower = follower.decode(prf_follower.output()).unwrap();
while prf_leader.wants_flush() || prf_follower.wants_flush() {
tokio::try_join!(
@@ -239,19 +233,10 @@ mod tests {
.unwrap();
}
assert_eq!(prf_out_leader.len(), 2);
assert_eq!(prf_out_leader.len(), prf_out_follower.len());
let prf_result_leader: Vec<u8> = prf_out_leader.try_recv().unwrap().unwrap();
let prf_result_follower: Vec<u8> = prf_out_follower.try_recv().unwrap().unwrap();
let prf_result_leader: Vec<u8> = prf_out_leader
.iter_mut()
.flat_map(|p| p.try_recv().unwrap().unwrap())
.collect();
let prf_result_follower: Vec<u8> = prf_out_follower
.iter_mut()
.flat_map(|p| p.try_recv().unwrap().unwrap())
.collect();
let expected = phash(key.to_vec(), &label_seed, iterations);
let expected = &phash(key.to_vec(), &label_seed, iterations)[..output_len];
assert_eq!(prf_result_leader, prf_result_follower);
assert_eq!(prf_result_leader, expected)

View File

@@ -1,24 +1,30 @@
//! Computes the whole PRF in MPC.
//! TLS 1.2 PRF function.
use crate::{hmac::hmac_sha256, PrfError};
use mpz_hash::sha256::Sha256;
use crate::{hmac::normal::HmacNormal, tls12::merge_vectors, FError};
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array, MemoryExt, Vector, ViewExt,
MemoryExt, Vector, ViewExt,
},
Vm,
};
#[derive(Debug)]
pub(crate) struct PrfFunction {
// The label, e.g. "master secret".
// The human-readable label, e.g. "master secret".
label: &'static [u8],
state: State,
// The start seed and the label, e.g. client_random + server_random + "master_secret".
/// The start seed and the label, e.g. client_random + server_random +
/// "master_secret".
start_seed_label: Option<Vec<u8>>,
a: Vec<PHash>,
p: Vec<PHash>,
seed_label_ref: Vector<U8>,
/// A_Hash functionalities for each iteration instantiated with the PRF
/// secret.
a_hash: Vec<HmacNormal>,
/// P_Hash functionalities for each iteration instantiated with the PRF
/// secret.
p_hash: Vec<HmacNormal>,
output: Vector<U8>,
}
impl PrfFunction {
@@ -27,64 +33,128 @@ impl PrfFunction {
const CF_LABEL: &[u8] = b"client finished";
const SF_LABEL: &[u8] = b"server finished";
/// Allocates master secret.
pub(crate) fn alloc_master_secret(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
Self::alloc(vm, Self::MS_LABEL, outer_partial, inner_partial, 48, 64)
hmac: HmacNormal,
) -> Result<Self, FError> {
Self::alloc(vm, Self::MS_LABEL, hmac, 48, 64)
}
/// Allocates key expansion.
pub(crate) fn alloc_key_expansion(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
Self::alloc(vm, Self::KEY_LABEL, outer_partial, inner_partial, 40, 64)
hmac: HmacNormal,
) -> Result<Self, FError> {
Self::alloc(vm, Self::KEY_LABEL, hmac, 40, 64)
}
/// Allocates client finished.
pub(crate) fn alloc_client_finished(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
Self::alloc(vm, Self::CF_LABEL, outer_partial, inner_partial, 12, 32)
hmac: HmacNormal,
) -> Result<Self, FError> {
Self::alloc(vm, Self::CF_LABEL, hmac, 12, 32)
}
/// Allocates server finished.
pub(crate) fn alloc_server_finished(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12, 32)
hmac: HmacNormal,
) -> Result<Self, FError> {
Self::alloc(vm, Self::SF_LABEL, hmac, 12, 32)
}
/// Allocates a new PRF with the given `hmac` instantiated with the PRF
/// secret.
fn alloc(
vm: &mut dyn Vm<Binary>,
label: &'static [u8],
hmac: HmacNormal,
output_len: usize,
seed_len: usize,
) -> Result<Self, FError> {
assert!(output_len > 0, "cannot compute 0 bytes for prf");
let iterations = output_len.div_ceil(32);
let msg_len_a = label.len() + seed_len;
let seed_label_ref: Vector<U8> = vm.alloc_vec(msg_len_a).map_err(FError::vm)?;
vm.mark_public(seed_label_ref).map_err(FError::vm)?;
let mut msg_a = seed_label_ref;
let mut p_out: Vec<Vector<U8>> = Vec::with_capacity(iterations);
let mut a_hash = Vec::with_capacity(iterations);
let mut p_hash = Vec::with_capacity(iterations);
for _ in 0..iterations {
let mut a = HmacNormal::from_other(&hmac)?;
a.set_msg(vm, &[msg_a])?;
let a_out: Vector<U8> = a.output()?.into();
msg_a = a_out;
a_hash.push(a);
let mut p = HmacNormal::from_other(&hmac)?;
p.set_msg(vm, &[a_out, seed_label_ref])?;
p_out.push(p.output()?.into());
p_hash.push(p);
}
Ok(Self {
label,
state: State::WantsSeed,
start_seed_label: None,
seed_label_ref,
a_hash,
p_hash,
output: merge_vectors(vm, p_out, output_len)?,
})
}
/// Whether this functionality needs to be flushed.
pub(crate) fn wants_flush(&self) -> bool {
let is_computing = match self.state {
State::Computing => true,
State::Finished => false,
let state_wants_flush = match self.state {
State::WantsSeed => self.start_seed_label.is_some(),
_ => false,
};
is_computing && self.start_seed_label.is_some()
state_wants_flush
|| self.a_hash.iter().any(|h| h.wants_flush())
|| self.p_hash.iter().any(|h| h.wants_flush())
}
pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), PrfError> {
if let State::Computing = self.state {
let a = self.a.first().expect("prf should be allocated");
let msg = *a.msg.first().expect("message for prf should be present");
/// Flushes the functionality.
pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), FError> {
// Flush every HMAC functionality.
self.a_hash.iter_mut().try_for_each(|h| h.flush())?;
self.p_hash.iter_mut().try_for_each(|h| h.flush())?;
let msg_value = self
.start_seed_label
.clone()
.expect("Start seed should have been set");
match self.state {
State::WantsSeed => {
if let Some(seed) = &self.start_seed_label {
vm.assign(self.seed_label_ref, seed.clone())
.map_err(FError::vm)?;
vm.commit(self.seed_label_ref).map_err(FError::vm)?;
vm.assign(msg, msg_value).map_err(PrfError::vm)?;
vm.commit(msg).map_err(PrfError::vm)?;
self.state = State::Finished;
self.state = State::SeedSet;
// Recurse.
self.flush(vm)?;
}
}
State::SeedSet => {
// We are complete when all HMACs are complete.
if self.a_hash.iter().all(|h| h.is_complete())
&& self.p_hash.iter().all(|h| h.is_complete())
{
self.state = State::Complete;
}
}
_ => (),
}
Ok(())
}
/// Sets the seed.
pub(crate) fn set_start_seed(&mut self, seed: Vec<u8>) {
let mut start_seed_label = self.label.to_vec();
start_seed_label.extend_from_slice(&seed);
@@ -92,83 +162,20 @@ impl PrfFunction {
self.start_seed_label = Some(start_seed_label);
}
pub(crate) fn output(&self) -> Vec<Array<U8, 32>> {
self.p.iter().map(|p| p.output).collect()
/// Returns the PRF output.
pub(crate) fn output(&self) -> Vector<U8> {
self.output
}
fn alloc(
vm: &mut dyn Vm<Binary>,
label: &'static [u8],
outer_partial: Sha256,
inner_partial: Sha256,
output_len: usize,
seed_len: usize,
) -> Result<Self, PrfError> {
let mut prf = Self {
label,
state: State::Computing,
start_seed_label: None,
a: vec![],
p: vec![],
};
assert!(output_len > 0, "cannot compute 0 bytes for prf");
let iterations = output_len.div_ceil(32);
let msg_len_a = label.len() + seed_len;
let seed_label_ref: Vector<U8> = vm.alloc_vec(msg_len_a).map_err(PrfError::vm)?;
vm.mark_public(seed_label_ref).map_err(PrfError::vm)?;
let mut msg_a = seed_label_ref;
for _ in 0..iterations {
let a = PHash::alloc(vm, outer_partial.clone(), inner_partial.clone(), &[msg_a])?;
msg_a = Vector::<U8>::from(a.output);
prf.a.push(a);
let p = PHash::alloc(
vm,
outer_partial.clone(),
inner_partial.clone(),
&[msg_a, seed_label_ref],
)?;
prf.p.push(p);
}
Ok(prf)
/// Whether this functionality is complete.
pub(crate) fn is_complete(&self) -> bool {
matches!(self.state, State::Complete)
}
}
#[derive(Debug, Clone, Copy)]
enum State {
Computing,
Finished,
}
#[derive(Debug, Clone)]
struct PHash {
msg: Vec<Vector<U8>>,
output: Array<U8, 32>,
}
impl PHash {
fn alloc(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
msg: &[Vector<U8>],
) -> Result<Self, PrfError> {
let mut inner_local = inner_partial;
msg.iter().for_each(|m| inner_local.update(m));
inner_local.compress(vm)?;
let inner_local = inner_local.finalize(vm)?;
let output = hmac_sha256(vm, outer_partial, inner_local)?;
let p_hash = Self {
msg: msg.to_vec(),
output,
};
Ok(p_hash)
}
WantsSeed,
SeedSet,
Complete,
}

View File

@@ -1,46 +1,30 @@
//! Computes some hashes of the PRF locally.
//! TLS 1.2 PRF function.
use std::collections::VecDeque;
use crate::{hmac::hmac_sha256, sha256, state_to_bytes, PrfError};
use mpz_core::bitvec::BitVec;
use mpz_hash::sha256::Sha256;
use crate::{hmac::reduced::HmacReduced, tls12::merge_vectors, FError};
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array, DecodeFutureTyped, MemoryExt, ViewExt,
MemoryExt, Vector,
},
Vm,
};
#[derive(Debug)]
pub(crate) struct PrfFunction {
// The label, e.g. "master secret".
// The human-readable label, e.g. "master secret".
label: &'static [u8],
// The start seed and the label, e.g. client_random + server_random + "master_secret".
start_seed_label: Option<Vec<u8>>,
iterations: usize,
state: PrfState,
a: VecDeque<AHash>,
p: VecDeque<PHash>,
}
#[derive(Debug)]
enum PrfState {
InnerPartial {
inner_partial: DecodeFutureTyped<BitVec, [u32; 8]>,
},
ComputeA {
iter: usize,
inner_partial: [u32; 8],
msg: Vec<u8>,
},
ComputeP {
iter: usize,
inner_partial: [u32; 8],
a_output: DecodeFutureTyped<BitVec, [u8; 32]>,
},
Done,
state: State,
/// A_Hash functionalities for each iteration instantiated with the PRF
/// secret.
a_hash: VecDeque<HmacReduced>,
/// P_Hash functionalities for each iteration instantiated with the PRF
/// secret.
p_hash: VecDeque<HmacReduced>,
output: Vector<U8>,
}
impl PrfFunction {
@@ -49,111 +33,222 @@ impl PrfFunction {
const CF_LABEL: &[u8] = b"client finished";
const SF_LABEL: &[u8] = b"server finished";
/// Allocates master secret.
pub(crate) fn alloc_master_secret(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
Self::alloc(vm, Self::MS_LABEL, outer_partial, inner_partial, 48)
hmac: HmacReduced,
) -> Result<Self, FError> {
Self::alloc(vm, Self::MS_LABEL, hmac, 48)
}
/// Allocates key expansion.
pub(crate) fn alloc_key_expansion(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
Self::alloc(vm, Self::KEY_LABEL, outer_partial, inner_partial, 40)
hmac: HmacReduced,
) -> Result<Self, FError> {
Self::alloc(vm, Self::KEY_LABEL, hmac, 40)
}
/// Allocates client finished.
pub(crate) fn alloc_client_finished(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
Self::alloc(vm, Self::CF_LABEL, outer_partial, inner_partial, 12)
hmac: HmacReduced,
) -> Result<Self, FError> {
Self::alloc(vm, Self::CF_LABEL, hmac, 12)
}
/// Allocates server finished.
pub(crate) fn alloc_server_finished(
vm: &mut dyn Vm<Binary>,
outer_partial: Sha256,
inner_partial: Sha256,
) -> Result<Self, PrfError> {
Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12)
hmac: HmacReduced,
) -> Result<Self, FError> {
Self::alloc(vm, Self::SF_LABEL, hmac, 12)
}
/// Allocates a new PRF with the given `hmac` instantiated with the PRF
/// secret.
fn alloc(
vm: &mut dyn Vm<Binary>,
label: &'static [u8],
hmac: HmacReduced,
output_len: usize,
) -> Result<Self, FError> {
assert!(output_len > 0, "cannot compute 0 bytes for prf");
let iterations = output_len.div_ceil(32);
let mut a_hash = VecDeque::with_capacity(iterations);
let mut p_hash = VecDeque::with_capacity(iterations);
// Create the required amount of HMAC instances.
let mut hmacs = vec![hmac];
for _ in 0..iterations * 2 - 1 {
hmacs.push(HmacReduced::from_other(vm, &hmacs[0])?);
}
let mut p_out: Vec<Vector<U8>> = Vec::with_capacity(iterations);
for _ in 0..iterations {
let a = hmacs.pop().expect("enough instances");
let p = hmacs.pop().expect("enough instances");
// Decode output as soon as it becomes available.
std::mem::drop(vm.decode(a.output()).map_err(FError::vm)?);
p_out.push(p.output().into());
a_hash.push_back(a);
p_hash.push_back(p);
}
Ok(Self {
label,
start_seed_label: None,
state: State::WantsSeed,
a_hash,
p_hash,
output: merge_vectors(vm, p_out, output_len)?,
})
}
/// Whether this functionality needs to be flushed.
pub(crate) fn wants_flush(&self) -> bool {
!matches!(self.state, PrfState::Done) && self.start_seed_label.is_some()
let state_wants_flush = match self.state {
State::WantsSeed => self.start_seed_label.is_some(),
State::ComputeFirstCycle { .. } => true,
State::ComputeCycle { .. } => true,
State::ComputeLastCycle { .. } => true,
_ => false,
};
state_wants_flush
|| self.a_hash.iter().any(|h| h.wants_flush())
|| self.p_hash.iter().any(|h| h.wants_flush())
}
pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), PrfError> {
match &mut self.state {
PrfState::InnerPartial { inner_partial } => {
let Some(inner_partial) = inner_partial.try_recv().map_err(PrfError::vm)? else {
return Ok(());
};
/// Flushes the functionality.
pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), FError> {
// Flush every HMAC functionality.
self.a_hash.iter_mut().try_for_each(|h| h.flush(vm))?;
self.p_hash.iter_mut().try_for_each(|h| h.flush(vm))?;
self.state = PrfState::ComputeA {
iter: 1,
inner_partial,
msg: self
.start_seed_label
.clone()
.expect("Start seed should have been set"),
};
self.flush(vm)?;
}
PrfState::ComputeA {
iter,
inner_partial,
msg,
} => {
let a = self.a.pop_front().expect("Prf AHash should be present");
assign_inner_local(vm, a.inner_local, *inner_partial, msg)?;
self.state = PrfState::ComputeP {
iter: *iter,
inner_partial: *inner_partial,
a_output: a.output,
};
}
PrfState::ComputeP {
iter,
inner_partial,
a_output,
} => {
let Some(output) = a_output.try_recv().map_err(PrfError::vm)? else {
return Ok(());
};
let p = self.p.pop_front().expect("Prf PHash should be present");
let mut msg = output.to_vec();
msg.extend_from_slice(
self.start_seed_label
.as_ref()
.expect("Start seed should have been set"),
);
assign_inner_local(vm, p.inner_local, *inner_partial, &msg)?;
if *iter == self.iterations {
self.state = PrfState::Done;
} else {
self.state = PrfState::ComputeA {
iter: *iter + 1,
inner_partial: *inner_partial,
msg: output.to_vec(),
};
// We recurse, so that this PHash and the next AHash could
// be computed in a single VM execute call.
match &self.state {
State::WantsSeed => {
if let Some(seed) = &self.start_seed_label {
self.state = State::ComputeFirstCycle { msg: seed.to_vec() };
// recurse.
self.flush(vm)?;
}
}
State::ComputeFirstCycle { msg } => {
let mut a = self.a_hash.pop_front().expect("not empty");
if !a.is_msg_set() {
a.set_msg(msg)?;
a.flush(vm)?;
}
let out = if a.is_complete() {
let mut a_out = vm.decode(a.output()).map_err(FError::vm)?;
a_out.try_recv().map_err(FError::vm)?
} else {
None
};
match out {
Some(out) => {
self.state = State::ComputeCycle { msg: out.to_vec() };
// Recurse to the next cycle.
self.flush(vm)?;
}
None => {
// Prepare to process this cycle again after VM executes.
self.a_hash.push_front(a);
self.state = State::ComputeFirstCycle { msg: msg.to_vec() };
}
}
}
State::ComputeCycle { msg } => {
if self.p_hash.len() == 1 {
// Recurse to the last cycle.
self.state = State::ComputeLastCycle { msg: msg.to_vec() };
self.flush(vm)?;
return Ok(());
}
let mut a = self.a_hash.pop_front().expect("not empty");
let mut p = self.p_hash.pop_front().expect("not empty");
if !a.is_msg_set() {
a.set_msg(msg)?;
a.flush(vm)?;
}
if !p.is_msg_set() {
let mut p_msg = msg.clone();
p_msg.extend_from_slice(
self.start_seed_label
.as_ref()
.expect("Start seed should have been set"),
);
p.set_msg(&p_msg)?;
p.flush(vm)?;
}
if !p.is_complete() {
// Prepare to process this cycle again after VM executes.
self.a_hash.push_front(a);
self.p_hash.push_front(p);
self.state = State::ComputeCycle { msg: msg.to_vec() };
return Ok(());
}
let out = if a.is_complete() {
let mut a_out = vm.decode(a.output()).map_err(FError::vm)?;
a_out.try_recv().map_err(FError::vm)?
} else {
None
};
match out {
Some(out) => {
// Recurse to the next cycle.
self.state = State::ComputeCycle { msg: out.to_vec() };
self.flush(vm)?;
}
None => {
// Prepare to process this cycle again after VM executes.
self.a_hash.push_front(a);
self.p_hash.push_front(p);
self.state = State::ComputeCycle { msg: msg.to_vec() };
}
}
}
State::ComputeLastCycle { msg } => {
let mut p = self.p_hash.pop_front().expect("not empty");
if !p.is_msg_set() {
let mut p_msg = msg.clone();
p_msg.extend_from_slice(
self.start_seed_label
.as_ref()
.expect("Start seed should have been set"),
);
p.set_msg(&p_msg)?;
p.flush(vm)?;
}
if !p.is_complete() {
// Prepare to process this cycle again after VM executes.
self.p_hash.push_front(p);
self.state = State::ComputeLastCycle { msg: msg.to_vec() };
} else {
self.state = State::Complete;
}
}
_ => (),
}
Ok(())
}
/// Sets the seed.
pub(crate) fn set_start_seed(&mut self, seed: Vec<u8>) {
let mut start_seed_label = self.label.to_vec();
start_seed_label.extend_from_slice(&seed);
@@ -161,88 +256,33 @@ impl PrfFunction {
self.start_seed_label = Some(start_seed_label);
}
pub(crate) fn output(&self) -> Vec<Array<U8, 32>> {
self.p.iter().map(|p| p.output).collect()
/// Returns the PRF output.
pub(crate) fn output(&self) -> Vector<U8> {
self.output
}
fn alloc(
vm: &mut dyn Vm<Binary>,
label: &'static [u8],
outer_partial: Sha256,
inner_partial: Sha256,
len: usize,
) -> Result<Self, PrfError> {
assert!(len > 0, "cannot compute 0 bytes for prf");
let iterations = len.div_ceil(32);
let (inner_partial, _) = inner_partial
.state()
.expect("state should be set for inner_partial");
let inner_partial = vm.decode(inner_partial).map_err(PrfError::vm)?;
let mut prf = Self {
label,
start_seed_label: None,
iterations,
state: PrfState::InnerPartial { inner_partial },
a: VecDeque::new(),
p: VecDeque::new(),
};
for _ in 0..iterations {
// setup A[i]
let inner_local: Array<U8, 32> = vm.alloc().map_err(PrfError::vm)?;
let output = hmac_sha256(vm, outer_partial.clone(), inner_local)?;
let output = vm.decode(output).map_err(PrfError::vm)?;
let a_hash = AHash {
inner_local,
output,
};
prf.a.push_front(a_hash);
// setup P[i]
let inner_local: Array<U8, 32> = vm.alloc().map_err(PrfError::vm)?;
let output = hmac_sha256(vm, outer_partial.clone(), inner_local)?;
let p_hash = PHash {
inner_local,
output,
};
prf.p.push_front(p_hash);
}
Ok(prf)
/// Whether this functionality is complete.
pub(crate) fn is_complete(&self) -> bool {
matches!(self.state, State::Complete)
}
}
fn assign_inner_local(
vm: &mut dyn Vm<Binary>,
inner_local: Array<U8, 32>,
inner_partial: [u32; 8],
msg: &[u8],
) -> Result<(), PrfError> {
let inner_local_value = sha256(inner_partial, 64, msg);
vm.mark_public(inner_local).map_err(PrfError::vm)?;
vm.assign(inner_local, state_to_bytes(inner_local_value))
.map_err(PrfError::vm)?;
vm.commit(inner_local).map_err(PrfError::vm)?;
Ok(())
}
/// Like PHash but stores the output as the decoding future because in the
/// reduced Prf we need to decode this output.
#[derive(Debug)]
struct AHash {
inner_local: Array<U8, 32>,
output: DecodeFutureTyped<BitVec, [u8; 32]>,
}
#[derive(Debug, Clone, Copy)]
struct PHash {
inner_local: Array<U8, 32>,
output: Array<U8, 32>,
#[derive(Debug, PartialEq)]
enum State {
WantsSeed,
/// To minimize the amount of VM execute calls, the PRF iterations are
/// divided into cycles.
/// Starting with iteration count i == 1, each cycle computes a tuple
/// (A_Hash(i), P_Hash(i-1)). Thus, during the first cycle, only A_Hash(1)
/// is computed and during the last cycle only P_Hash(i) is computed.
ComputeFirstCycle {
msg: Vec<u8>,
},
ComputeCycle {
msg: Vec<u8>,
},
ComputeLastCycle {
msg: Vec<u8>,
},
Complete,
}

View File

@@ -0,0 +1,2 @@
pub(crate) mod function;
pub(crate) use function::Prf;

View File

@@ -1,103 +0,0 @@
use crate::{
prf::{function::Prf, merge_outputs},
PrfError, PrfOutput, SessionKeys,
};
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array, FromRaw, ToRaw,
},
Vm,
};
// The `SessionKeys` variant is much larger than the others; boxing it would
// shrink the enum, but the lint is deliberately allowed here instead.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub(crate) enum State {
    /// Initial state, before PRF resources have been allocated.
    Initialized,
    /// PRF instances are allocated; computing the master secret and the
    /// session keys.
    SessionKeys {
        // Client random, stored until the server random arrives so both can
        // be combined into the PRF seeds.
        client_random: Option<[u8; 32]>,
        master_secret: Prf,
        key_expansion: Prf,
        client_finished: Prf,
        server_finished: Prf,
    },
    /// Computing the client finished verify data.
    ClientFinished {
        client_finished: Prf,
        server_finished: Prf,
    },
    /// Computing the server finished verify data.
    ServerFinished {
        server_finished: Prf,
    },
    /// All PRF outputs have been computed.
    Complete,
    /// Placeholder left behind by `take()`; observed only if a transition
    /// errored part-way through.
    Error,
}
impl State {
    /// Moves the current state out, leaving `State::Error` in its place.
    ///
    /// Callers are expected to write a valid state back; if they return early
    /// (e.g. on error), the PRF remains poisoned in `Error`.
    pub(crate) fn take(&mut self) -> State {
        std::mem::replace(self, State::Error)
    }

    /// Builds the [`PrfOutput`] (session keys and both verify-data values)
    /// from the allocated PRF instances.
    ///
    /// # Errors
    ///
    /// Returns a state error unless the PRF is in the `SessionKeys` state.
    pub(crate) fn prf_output(&self, vm: &mut dyn Vm<Binary>) -> Result<PrfOutput, PrfError> {
        let State::SessionKeys {
            key_expansion,
            client_finished,
            server_finished,
            ..
        } = self
        else {
            return Err(PrfError::state(
                "Prf output can only be computed while in \"SessionKeys\" state",
            ));
        };

        let keys = get_session_keys(key_expansion.output(), vm)?;
        let cf_vd = get_client_finished_vd(client_finished.output(), vm)?;
        let sf_vd = get_server_finished_vd(server_finished.output(), vm)?;

        let output = PrfOutput { keys, cf_vd, sf_vd };
        Ok(output)
    }
}
/// Merges the key-expansion PRF outputs into the four TLS 1.2 session keys.
///
/// The 40-byte key block layout is:
/// `client_write_key(16) || server_write_key(16) || client_iv(4) || server_iv(4)`.
fn get_session_keys(
    output: Vec<Array<U8, 32>>,
    vm: &mut dyn Vm<Binary>,
) -> Result<SessionKeys, PrfError> {
    let mut keys = merge_outputs(vm, output, 40)?;
    debug_assert!(keys.len() == 40, "session keys len should be 40");

    // Split from the back so earlier offsets stay valid after each
    // `split_off`; the `try_from` calls cannot fail given the 40-byte length
    // asserted above.
    let server_iv = Array::<U8, 4>::try_from(keys.split_off(36)).unwrap();
    let client_iv = Array::<U8, 4>::try_from(keys.split_off(32)).unwrap();
    let server_write_key = Array::<U8, 16>::try_from(keys.split_off(16)).unwrap();
    let client_write_key = Array::<U8, 16>::try_from(keys).unwrap();

    let session_keys = SessionKeys {
        client_write_key,
        server_write_key,
        client_iv,
        server_iv,
    };

    Ok(session_keys)
}
/// Merges the client-finished PRF outputs into the 12-byte verify data.
fn get_client_finished_vd(
    output: Vec<Array<U8, 32>>,
    vm: &mut dyn Vm<Binary>,
) -> Result<Array<U8, 12>, PrfError> {
    let merged = merge_outputs(vm, output, 12)?;
    // Reinterpret the 12-byte vector as a fixed-size array reference.
    Ok(<Array<U8, 12> as FromRaw<Binary>>::from_raw(merged.to_raw()))
}
/// Merges the server-finished PRF outputs into the 12-byte verify data.
fn get_server_finished_vd(
    output: Vec<Array<U8, 32>>,
    vm: &mut dyn Vm<Binary>,
) -> Result<Array<U8, 12>, PrfError> {
    let merged = merge_outputs(vm, output, 12)?;
    // Reinterpret the 12-byte vector as a fixed-size array reference.
    Ok(<Array<U8, 12> as FromRaw<Binary>>::from_raw(merged.to_raw()))
}

View File

@@ -1,9 +1,20 @@
use crate::{sha256, state_to_bytes};
use crate::hmac::clear;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_ot::ideal::cot::{ideal_cot, IdealCOTReceiver, IdealCOTSender};
use mpz_vm_core::memory::correlated::Delta;
use rand::{rngs::StdRng, Rng, SeedableRng};
pub(crate) const SHA256_IV: [u32; 8] = [
0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
];
/// Creates a connected pair of mock VMs (garbler + evaluator) backed by an
/// ideal correlated OT, for use in tests.
///
/// Uses a fixed RNG seed and zero garbler seed so test runs are
/// deterministic.
pub(crate) fn mock_vm() -> (Garbler<IdealCOTSender>, Evaluator<IdealCOTReceiver>) {
    let mut rng = StdRng::seed_from_u64(0);
    let delta = Delta::random(&mut rng);

    let (cot_send, cot_recv) = ideal_cot(delta.into_inner());

    let gen = Garbler::new(cot_send, [0u8; 16], delta);
    let ev = Evaluator::new(cot_recv);

    (gen, ev)
}
pub(crate) fn prf_ms(pms: [u8; 32], client_random: [u8; 32], server_random: [u8; 32]) -> [u8; 48] {
let mut label_start_seed = b"master secret".to_vec();
@@ -57,7 +68,7 @@ pub(crate) fn phash(key: Vec<u8>, seed: &[u8], iterations: usize) -> Vec<u8> {
a_cache.push(seed.to_vec());
for i in 0..iterations {
let a_i = hmac_sha256(key.clone(), &a_cache[i]);
let a_i = clear::hmac_sha256(&key, &a_cache[i]);
a_cache.push(a_i.to_vec());
}
@@ -67,64 +78,13 @@ pub(crate) fn phash(key: Vec<u8>, seed: &[u8], iterations: usize) -> Vec<u8> {
let mut a_i_seed = a_cache[i + 1].clone();
a_i_seed.extend_from_slice(seed);
let hash = hmac_sha256(key.clone(), &a_i_seed);
let hash = clear::hmac_sha256(&key, &a_i_seed);
output.extend_from_slice(&hash);
}
output
}
pub(crate) fn hmac_sha256(key: Vec<u8>, msg: &[u8]) -> [u8; 32] {
let outer_partial = compute_outer_partial(key.clone());
let inner_local = compute_inner_local(key, msg);
let hmac = sha256(outer_partial, 64, &state_to_bytes(inner_local));
state_to_bytes(hmac)
}
/// Computes the SHA-256 state after compressing the opad-masked key block,
/// i.e. the HMAC "outer partial" state.
pub(crate) fn compute_outer_partial(mut key: Vec<u8>) -> [u32; 8] {
    assert!(key.len() <= 64);
    // Zero-pad the key to a full 64-byte block, then XOR with opad (0x5c).
    key.resize(64, 0_u8);
    let mut block = [0u8; 64];
    for (dst, src) in block.iter_mut().zip(&key) {
        *dst = src ^ 0x5c;
    }
    compress_256(SHA256_IV, &block)
}
/// Computes the HMAC "inner local" hash state: SHA-256 over the ipad-masked
/// key block followed by `msg`.
pub(crate) fn compute_inner_local(mut key: Vec<u8>, msg: &[u8]) -> [u32; 8] {
    assert!(key.len() <= 64);
    // Zero-pad the key to a full 64-byte block, then XOR with ipad (0x36).
    key.resize(64, 0_u8);
    let mut block = [0u8; 64];
    for (dst, src) in block.iter_mut().zip(&key) {
        *dst = src ^ 0x36;
    }
    // Continue hashing `msg` from the key-block state (64 bytes consumed).
    let state = compress_256(SHA256_IV, &block);
    sha256(state, 64, msg)
}
/// Runs the raw SHA-256 compression function over `msg` starting from
/// `state`, returning the updated state.
///
/// NOTE(review): `BlockBuffer` feeds only complete 64-byte blocks to
/// `compress256`; a trailing partial block appears to be left in the buffer
/// and dropped — callers should pass whole blocks. Confirm against callers.
pub(crate) fn compress_256(mut state: [u32; 8], msg: &[u8]) -> [u32; 8] {
    use sha2::{
        compress256,
        digest::{
            block_buffer::{BlockBuffer, Eager},
            generic_array::typenum::U64,
        },
    };

    let mut buffer = BlockBuffer::<U64, Eager>::default();
    buffer.digest_blocks(msg, |b| compress256(&mut state, b));
    state
}
// Borrowed from Rustls for testing
// https://github.com/rustls/rustls/blob/main/rustls/src/tls12/prf.rs
mod ring_prf {
@@ -244,3 +204,21 @@ fn test_prf_reference_sf() {
assert_eq!(sf_vd, expected_sf_vd);
}
#[test]
// NOTE(review): despite the "key_schedule" name, this body exercises the TLS
// 1.2 "server finished" PRF (same shape as `test_prf_reference_sf`, different
// RNG seed) — confirm the test name matches the intent.
fn test_key_schedule_reference_sf() {
    use ring_prf::prf as prf_ref;

    // Deterministic test inputs.
    let mut rng = StdRng::from_seed([4; 32]);
    let ms: [u8; 48] = rng.random();
    let label: &[u8] = b"server finished";
    let handshake_hash: [u8; 32] = rng.random();

    // Value computed by the implementation under test.
    let sf_vd = prf_sf_vd(ms, handshake_hash);

    // Reference value from the Rustls-derived PRF.
    let mut expected_sf_vd: [u8; 12] = [0; 12];
    prf_ref(&mut expected_sf_vd, &ms, label, &handshake_hash);

    assert_eq!(sf_vd, expected_sf_vd);
}

View File

@@ -0,0 +1,556 @@
//! Functionality for computing HMAC-SHA-256-based TLS 1.2 PRF.
use std::{fmt::Debug, sync::Arc};
use mpz_circuits::{Circuit, CircuitBuilder};
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array, StaticSize, Vector,
},
Call, CallableExt, Vm,
};
use tracing::instrument;
use crate::{hmac::Hmac, prf::Prf, tls12::state::State, FError, Mode};
mod state;
/// Functionality for computing HMAC-SHA-256-based TLS 1.2 PRF.
#[derive(Debug)]
pub struct Tls12Prf {
    // PRF computation mode, fixed at construction and forwarded to every
    // HMAC/PRF allocation.
    mode: Mode,
    // Current stage of the PRF state machine (see the `state` module).
    state: State,
}
impl Tls12Prf {
    /// Creates a new instance of the PRF.
    ///
    /// # Arguments
    ///
    /// `mode` - The PRF mode.
    pub fn new(mode: Mode) -> Tls12Prf {
        Self {
            mode,
            state: State::Initialized,
        }
    }

    /// Allocates resources for the PRF.
    ///
    /// # Arguments
    ///
    /// * `vm` - Virtual machine.
    /// * `pms` - The pre-master secret.
    #[instrument(level = "debug", skip_all, err)]
    pub fn alloc(
        &mut self,
        vm: &mut dyn Vm<Binary>,
        pms: Array<U8, 32>,
    ) -> Result<PrfOutput, FError> {
        // Only valid from the initial state; `take` swaps the state out, so a
        // failure below leaves the PRF unusable (presumably an error
        // placeholder — see `State::take`; TODO confirm).
        let State::Initialized = self.state.take() else {
            return Err(FError::state("PRF not in initialized state"));
        };

        let mode = self.mode;

        // The master secret PRF is keyed by the pre-master secret.
        let hmac_pms = Hmac::alloc(vm, pms.into(), mode)?;
        let master_secret = Prf::alloc_master_secret(mode, vm, hmac_pms)?;

        // Three separate HMAC instances keyed by the master secret, one for
        // each downstream PRF (key expansion, client/server finished).
        let hmac_ms1: Hmac = Hmac::alloc(vm, master_secret.output(), mode)?;
        let hmac_ms2 = Hmac::from_other(vm, &hmac_ms1)?;
        let hmac_ms3 = Hmac::from_other(vm, &hmac_ms1)?;

        let key_expansion = Prf::alloc_key_expansion(mode, vm, hmac_ms1)?;
        let client_finished = Prf::alloc_client_finished(mode, vm, hmac_ms2)?;
        let server_finished = Prf::alloc_server_finished(mode, vm, hmac_ms3)?;

        self.state = State::SessionKeys {
            client_random: None,
            master_secret,
            key_expansion,
            client_finished,
            server_finished,
        };

        self.state.prf_output()
    }

    /// Sets the client random.
    ///
    /// Must be called before [`Self::set_server_random`].
    ///
    /// # Arguments
    ///
    /// * `random` - The client random.
    #[instrument(level = "debug", skip_all, err)]
    pub fn set_client_random(&mut self, random: [u8; 32]) -> Result<(), FError> {
        let State::SessionKeys { client_random, .. } = &mut self.state else {
            return Err(FError::state("PRF not set up"));
        };

        *client_random = Some(random);
        Ok(())
    }

    /// Sets the server random.
    ///
    /// # Arguments
    ///
    /// * `random` - The server random.
    #[instrument(level = "debug", skip_all, err)]
    pub fn set_server_random(&mut self, random: [u8; 32]) -> Result<(), FError> {
        let State::SessionKeys {
            client_random,
            master_secret,
            key_expansion,
            ..
        } = &mut self.state
        else {
            return Err(FError::state("PRF not set up"));
        };

        // Panics if `set_client_random` was not called first.
        let client_random = client_random.expect("Client random should have been set by now");
        let server_random = random;

        // Master secret seed: client_random || server_random.
        let mut seed_ms = client_random.to_vec();
        seed_ms.extend_from_slice(&server_random);
        master_secret.set_start_seed(seed_ms);

        // Key expansion seed uses the reverse order:
        // server_random || client_random.
        let mut seed_ke = server_random.to_vec();
        seed_ke.extend_from_slice(&client_random);
        key_expansion.set_start_seed(seed_ke);

        Ok(())
    }

    /// Sets the client finished handshake hash.
    ///
    /// # Arguments
    ///
    /// * `handshake_hash` - The handshake transcript hash.
    #[instrument(level = "debug", skip_all, err)]
    pub fn set_cf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), FError> {
        let State::ClientFinished {
            client_finished, ..
        } = &mut self.state
        else {
            return Err(FError::state("PRF not in client finished state"));
        };

        let seed_cf = handshake_hash.to_vec();
        client_finished.set_start_seed(seed_cf);

        Ok(())
    }

    /// Sets the server finished handshake hash.
    ///
    /// # Arguments
    ///
    /// * `handshake_hash` - The handshake transcript hash.
    #[instrument(level = "debug", skip_all, err)]
    pub fn set_sf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), FError> {
        let State::ServerFinished { server_finished } = &mut self.state else {
            return Err(FError::state("PRF not in server finished state"));
        };

        let seed_sf = handshake_hash.to_vec();
        server_finished.set_start_seed(seed_sf);

        Ok(())
    }

    /// Returns if the PRF needs to be flushed.
    ///
    /// Delegates to whichever inner PRF(s) the current state is driving.
    pub fn wants_flush(&self) -> bool {
        match &self.state {
            State::SessionKeys {
                master_secret,
                key_expansion,
                ..
            } => master_secret.wants_flush() || key_expansion.wants_flush(),
            State::ClientFinished {
                client_finished, ..
            } => client_finished.wants_flush(),
            State::ServerFinished { server_finished } => server_finished.wants_flush(),
            _ => false,
        }
    }

    /// Flushes the PRF.
    ///
    /// Drives the state machine forward: each state flushes its active inner
    /// PRF(s) and advances once they report completion; otherwise the state
    /// is rebuilt unchanged and tried again on the next flush.
    pub fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), FError> {
        self.state = match self.state.take() {
            State::SessionKeys {
                client_random,
                mut master_secret,
                mut key_expansion,
                client_finished,
                server_finished,
            } => {
                master_secret.flush(vm)?;
                key_expansion.flush(vm)?;
                if master_secret.is_complete() && key_expansion.is_complete() {
                    State::ClientFinished {
                        client_finished,
                        server_finished,
                    }
                } else {
                    State::SessionKeys {
                        client_random,
                        master_secret,
                        key_expansion,
                        client_finished,
                        server_finished,
                    }
                }
            }
            State::ClientFinished {
                mut client_finished,
                server_finished,
            } => {
                client_finished.flush(vm)?;
                if client_finished.is_complete() {
                    State::ServerFinished { server_finished }
                } else {
                    State::ClientFinished {
                        client_finished,
                        server_finished,
                    }
                }
            }
            State::ServerFinished {
                mut server_finished,
            } => {
                server_finished.flush(vm)?;
                if server_finished.is_complete() {
                    State::Complete
                } else {
                    State::ServerFinished { server_finished }
                }
            }
            // Initialized / Complete / error states are left as-is.
            other => other,
        };

        Ok(())
    }
}
/// PRF output.
///
/// The fields are references into VM memory allocated by [`Tls12Prf::alloc`];
/// decode them through the VM to obtain cleartext values.
#[derive(Debug, Clone, Copy)]
pub struct PrfOutput {
    /// TLS session keys.
    pub keys: SessionKeys,
    /// Client finished verify data.
    pub cf_vd: Array<U8, 12>,
    /// Server finished verify data.
    pub sf_vd: Array<U8, 12>,
}
/// Session keys computed by the PRF.
///
/// Key sizes (16-byte write keys, 4-byte IVs) match the 40-byte TLS 1.2 key
/// block produced by key expansion.
#[derive(Debug, Clone, Copy)]
pub struct SessionKeys {
    /// Client write key.
    pub client_write_key: Array<U8, 16>,
    /// Server write key.
    pub server_write_key: Array<U8, 16>,
    /// Client IV.
    pub client_iv: Array<U8, 4>,
    /// Server IV.
    pub server_iv: Array<U8, 4>,
}
/// Merges vectors, returning the merged vector truncated to the
/// `output_bytes` length.
pub(crate) fn merge_vectors(
    vm: &mut dyn Vm<Binary>,
    inputs: Vec<Vector<U8>>,
    output_bytes: usize,
) -> Result<Vector<U8>, FError> {
    let total_bytes: usize = inputs.iter().map(|input| input.len()).sum();
    assert!(output_bytes <= total_bytes);

    // Build an identity circuit wide enough for all inputs and feed each
    // input vector to it as an argument.
    let circ = gen_merge_circ(total_bytes * U8::SIZE);
    let call = inputs
        .iter()
        .copied()
        .fold(Call::builder(circ), |builder, input| builder.arg(input))
        .build()
        .map_err(FError::vm)?;

    let mut merged: Vector<U8> = vm.call(call).map_err(FError::vm)?;
    merged.truncate(output_bytes);
    Ok(merged)
}
/// Builds an identity circuit over `size` feeds, used by [`merge_vectors`]
/// to concatenate its inputs into a single output vector.
///
/// `size` is a bit count and is expected to be a multiple of 8 (callers pass
/// `len * U8::SIZE`); `chunks_exact(8)` silently drops any remainder feeds.
fn gen_merge_circ(size: usize) -> Arc<Circuit> {
    let mut builder = CircuitBuilder::new();
    let inputs = (0..size).map(|_| builder.add_input()).collect::<Vec<_>>();

    // Fix: the original nested a second `chunks_exact(8)` inside the first;
    // chunking an 8-element chunk by 8 always yields exactly that chunk, so
    // one level of chunking is sufficient and behavior is unchanged.
    for byte in inputs.chunks_exact(8) {
        for &feed in byte.iter() {
            let output = builder.add_id_gate(feed);
            builder.add_output(output);
        }
    }

    Arc::new(builder.build().expect("merge circuit is valid"))
}
#[cfg(test)]
mod tests {
    use crate::{
        test_utils::{mock_vm, prf_cf_vd, prf_keys, prf_ms, prf_sf_vd},
        tls12::merge_vectors,
        Mode, SessionKeys, Tls12Prf,
    };
    use mpz_common::context::test_st_context;
    use mpz_vm_core::{
        memory::{binary::U8, Array, MemoryExt, Vector, ViewExt},
        Execute,
    };
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use rstest::*;

    /// End-to-end test of the TLS 1.2 PRF in both modes: derives session
    /// keys, then the client Finished verify data, then the server Finished
    /// verify data, and checks that both parties agree and match reference
    /// values computed in the clear.
    #[rstest]
    #[case::normal(Mode::Normal)]
    #[case::reduced(Mode::Reduced)]
    #[tokio::test]
    async fn test_tls12prf(#[case] mode: Mode) {
        // Deterministic RNG so the test is reproducible.
        let mut rng = StdRng::seed_from_u64(1);
        // Test input.
        let pms: [u8; 32] = rng.random();
        let client_random: [u8; 32] = rng.random();
        let server_random: [u8; 32] = rng.random();
        let cf_hs_hash: [u8; 32] = rng.random();
        let sf_hs_hash: [u8; 32] = rng.random();
        // Expected output, computed in the clear with the test helpers.
        let ms_expected = prf_ms(pms, client_random, server_random);
        let [cwk_expected, swk_expected, civ_expected, siv_expected] =
            prf_keys(ms_expected, client_random, server_random);
        let cwk_expected: [u8; 16] = cwk_expected.try_into().unwrap();
        let swk_expected: [u8; 16] = swk_expected.try_into().unwrap();
        let civ_expected: [u8; 4] = civ_expected.try_into().unwrap();
        let siv_expected: [u8; 4] = siv_expected.try_into().unwrap();
        let cf_vd_expected = prf_cf_vd(ms_expected, cf_hs_hash);
        let sf_vd_expected = prf_sf_vd(ms_expected, sf_hs_hash);
        let cf_vd_expected: [u8; 12] = cf_vd_expected.try_into().unwrap();
        let sf_vd_expected: [u8; 12] = sf_vd_expected.try_into().unwrap();
        // Set up vm and prf.
        let (mut ctx_a, mut ctx_b) = test_st_context(128);
        let (mut leader, mut follower) = mock_vm();
        // The PMS is marked public here for test simplicity; both parties
        // assign the same value.
        let leader_pms: Array<U8, 32> = leader.alloc().unwrap();
        leader.mark_public(leader_pms).unwrap();
        leader.assign(leader_pms, pms).unwrap();
        leader.commit(leader_pms).unwrap();
        let follower_pms: Array<U8, 32> = follower.alloc().unwrap();
        follower.mark_public(follower_pms).unwrap();
        follower.assign(follower_pms, pms).unwrap();
        follower.commit(follower_pms).unwrap();
        let mut prf_leader = Tls12Prf::new(mode);
        let mut prf_follower = Tls12Prf::new(mode);
        let leader_prf_out = prf_leader.alloc(&mut leader, leader_pms).unwrap();
        let follower_prf_out = prf_follower.alloc(&mut follower, follower_pms).unwrap();
        // client_random and server_random.
        prf_leader.set_client_random(client_random).unwrap();
        prf_follower.set_client_random(client_random).unwrap();
        prf_leader.set_server_random(server_random).unwrap();
        prf_follower.set_server_random(server_random).unwrap();
        let SessionKeys {
            client_write_key: cwk_leader,
            server_write_key: swk_leader,
            client_iv: civ_leader,
            server_iv: siv_leader,
        } = leader_prf_out.keys;
        // Request decoding of the leader's key material.
        let mut cwk_leader = leader.decode(cwk_leader).unwrap();
        let mut swk_leader = leader.decode(swk_leader).unwrap();
        let mut civ_leader = leader.decode(civ_leader).unwrap();
        let mut siv_leader = leader.decode(siv_leader).unwrap();
        let SessionKeys {
            client_write_key: cwk_follower,
            server_write_key: swk_follower,
            client_iv: civ_follower,
            server_iv: siv_follower,
        } = follower_prf_out.keys;
        let mut cwk_follower = follower.decode(cwk_follower).unwrap();
        let mut swk_follower = follower.decode(swk_follower).unwrap();
        let mut civ_follower = follower.decode(civ_follower).unwrap();
        let mut siv_follower = follower.decode(siv_follower).unwrap();
        // Drive both PRFs to completion, flushing and executing both VMs in
        // lockstep until neither side needs another flush.
        while prf_leader.wants_flush() || prf_follower.wants_flush() {
            tokio::try_join!(
                async {
                    prf_leader.flush(&mut leader).unwrap();
                    leader.execute_all(&mut ctx_a).await
                },
                async {
                    prf_follower.flush(&mut follower).unwrap();
                    follower.execute_all(&mut ctx_b).await
                }
            )
            .unwrap();
        }
        let cwk_leader = cwk_leader.try_recv().unwrap().unwrap();
        let swk_leader = swk_leader.try_recv().unwrap().unwrap();
        let civ_leader = civ_leader.try_recv().unwrap().unwrap();
        let siv_leader = siv_leader.try_recv().unwrap().unwrap();
        let cwk_follower = cwk_follower.try_recv().unwrap().unwrap();
        let swk_follower = swk_follower.try_recv().unwrap().unwrap();
        let civ_follower = civ_follower.try_recv().unwrap().unwrap();
        let siv_follower = siv_follower.try_recv().unwrap().unwrap();
        // Both parties must agree, and must match the reference values.
        assert_eq!(cwk_leader, cwk_follower);
        assert_eq!(swk_leader, swk_follower);
        assert_eq!(civ_leader, civ_follower);
        assert_eq!(siv_leader, siv_follower);
        assert_eq!(cwk_leader, cwk_expected);
        assert_eq!(swk_leader, swk_expected);
        assert_eq!(civ_leader, civ_expected);
        assert_eq!(siv_leader, siv_expected);
        // client finished.
        prf_leader.set_cf_hash(cf_hs_hash).unwrap();
        prf_follower.set_cf_hash(cf_hs_hash).unwrap();
        let cf_vd_leader = leader_prf_out.cf_vd;
        let cf_vd_follower = follower_prf_out.cf_vd;
        let mut cf_vd_leader = leader.decode(cf_vd_leader).unwrap();
        let mut cf_vd_follower = follower.decode(cf_vd_follower).unwrap();
        while prf_leader.wants_flush() || prf_follower.wants_flush() {
            tokio::try_join!(
                async {
                    prf_leader.flush(&mut leader).unwrap();
                    leader.execute_all(&mut ctx_a).await
                },
                async {
                    prf_follower.flush(&mut follower).unwrap();
                    follower.execute_all(&mut ctx_b).await
                }
            )
            .unwrap();
        }
        let cf_vd_leader = cf_vd_leader.try_recv().unwrap().unwrap();
        let cf_vd_follower = cf_vd_follower.try_recv().unwrap().unwrap();
        assert_eq!(cf_vd_leader, cf_vd_follower);
        assert_eq!(cf_vd_leader, cf_vd_expected);
        // server finished.
        prf_leader.set_sf_hash(sf_hs_hash).unwrap();
        prf_follower.set_sf_hash(sf_hs_hash).unwrap();
        let sf_vd_leader = leader_prf_out.sf_vd;
        let sf_vd_follower = follower_prf_out.sf_vd;
        let mut sf_vd_leader = leader.decode(sf_vd_leader).unwrap();
        let mut sf_vd_follower = follower.decode(sf_vd_follower).unwrap();
        while prf_leader.wants_flush() || prf_follower.wants_flush() {
            tokio::try_join!(
                async {
                    prf_leader.flush(&mut leader).unwrap();
                    leader.execute_all(&mut ctx_a).await
                },
                async {
                    prf_follower.flush(&mut follower).unwrap();
                    follower.execute_all(&mut ctx_b).await
                }
            )
            .unwrap();
        }
        let sf_vd_leader = sf_vd_leader.try_recv().unwrap().unwrap();
        let sf_vd_follower = sf_vd_follower.try_recv().unwrap().unwrap();
        assert_eq!(sf_vd_leader, sf_vd_follower);
        assert_eq!(sf_vd_leader, sf_vd_expected);
    }

    /// Checks that `merge_vectors` concatenates its inputs in order and
    /// truncates the result to the requested length (here 48 of 64 bytes).
    #[tokio::test]
    async fn test_merge_outputs() {
        let (mut ctx_a, mut ctx_b) = test_st_context(8);
        let (mut leader, mut follower) = mock_vm();
        let input1: [u8; 32] = std::array::from_fn(|i| i as u8);
        let input2: [u8; 32] = std::array::from_fn(|i| i as u8 + 32);
        // Expected: input1 || input2, truncated to 48 bytes.
        let mut expected = input1.to_vec();
        expected.extend_from_slice(&input2);
        expected.truncate(48);
        // leader
        let input1_leader: Vector<U8> = leader.alloc_vec(32).unwrap();
        let input2_leader: Vector<U8> = leader.alloc_vec(32).unwrap();
        leader.mark_public(input1_leader).unwrap();
        leader.mark_public(input2_leader).unwrap();
        leader.assign(input1_leader, input1.to_vec()).unwrap();
        leader.assign(input2_leader, input2.to_vec()).unwrap();
        leader.commit(input1_leader).unwrap();
        leader.commit(input2_leader).unwrap();
        let merged_leader =
            merge_vectors(&mut leader, vec![input1_leader, input2_leader], 48).unwrap();
        let mut merged_leader = leader.decode(merged_leader).unwrap();
        // follower
        let input1_follower: Vector<U8> = follower.alloc_vec(32).unwrap();
        let input2_follower: Vector<U8> = follower.alloc_vec(32).unwrap();
        follower.mark_public(input1_follower).unwrap();
        follower.mark_public(input2_follower).unwrap();
        follower.assign(input1_follower, input1.to_vec()).unwrap();
        follower.assign(input2_follower, input2.to_vec()).unwrap();
        follower.commit(input1_follower).unwrap();
        follower.commit(input2_follower).unwrap();
        let merged_follower =
            merge_vectors(&mut follower, vec![input1_follower, input2_follower], 48).unwrap();
        let mut merged_follower = follower.decode(merged_follower).unwrap();
        tokio::try_join!(
            leader.execute_all(&mut ctx_a),
            follower.execute_all(&mut ctx_b)
        )
        .unwrap();
        let merged_leader = merged_leader.try_recv().unwrap().unwrap();
        let merged_follower = merged_follower.try_recv().unwrap().unwrap();
        assert_eq!(merged_leader, merged_follower);
        assert_eq!(merged_leader, expected);
    }
}

View File

@@ -0,0 +1,79 @@
use crate::{prf::Prf, FError, PrfOutput, SessionKeys};
use mpz_vm_core::memory::{binary::U8, Array};
/// State of the TLS 1.2 PRF functionality.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub(crate) enum State {
    /// The initial state, prior to allocation.
    Initialized,
    /// The state in which the session keys are computed.
    SessionKeys {
        // Client random, once it has been set.
        client_random: Option<[u8; 32]>,
        master_secret: Prf,
        key_expansion: Prf,
        client_finished: Prf,
        server_finished: Prf,
    },
    /// The state in which the client Finished verify data is computed.
    ClientFinished {
        client_finished: Prf,
        server_finished: Prf,
    },
    /// The state in which the server Finished verify data is computed.
    ServerFinished {
        server_finished: Prf,
    },
    /// All outputs have been computed.
    Complete,
    /// The sink state entered when a state was taken out (see `take`) or an
    /// operation failed.
    Error,
}
impl State {
    /// Takes the current state, leaving the `Error` state in its place.
    pub(crate) fn take(&mut self) -> State {
        std::mem::replace(self, State::Error)
    }

    /// Builds the PRF output from the PRF instances held by the
    /// `SessionKeys` state.
    ///
    /// Returns an error if the state machine is in any other state.
    pub(crate) fn prf_output(&self) -> Result<PrfOutput, FError> {
        match self {
            State::SessionKeys {
                key_expansion,
                client_finished,
                server_finished,
                ..
            } => {
                let keys = get_session_keys(
                    key_expansion
                        .output()
                        .try_into()
                        .expect("session keys are 40 bytes"),
                );
                Ok(PrfOutput {
                    keys,
                    cf_vd: client_finished
                        .output()
                        .try_into()
                        .expect("client finished is 12 bytes"),
                    sf_vd: server_finished
                        .output()
                        .try_into()
                        .expect("server finished is 12 bytes"),
                })
            }
            _ => Err(FError::state(
                "Prf output can only be computed while in \"SessionKeys\" state",
            )),
        }
    }
}
/// Splits the 40-byte key expansion block into its TLS 1.2 layout:
/// client_write_key (16) || server_write_key (16) || client_iv (4) ||
/// server_iv (4).
fn get_session_keys(keys: Array<U8, 40>) -> SessionKeys {
    SessionKeys {
        client_write_key: keys.get::<16>(0).expect("within bounds"),
        server_write_key: keys.get::<16>(16).expect("within bounds"),
        client_iv: keys.get::<4>(32).expect("within bounds"),
        server_iv: keys.get::<4>(36).expect("within bounds"),
    }
}

View File

@@ -0,0 +1,605 @@
//! Functionality for computing HMAC-SHA256-based TLS 1.3 key schedule.
use std::mem;
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array, MemoryExt,
},
OneTimePad, Vm,
};
use rand::RngCore;
use crate::{
hmac::Hmac,
kdf::{expand::hkdf_expand_label, extract::HkdfExtract},
tls13::{application::ApplicationSecrets, handshake::HandshakeSecrets},
FError, Mode,
};
mod application;
mod handshake;
/// Functionality role.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Role {
    /// Leader.
    ///
    /// The leader learns handshake secrets and locally finishes the handshake.
    Leader,
    /// Follower.
    ///
    /// The follower participates in the computation but does not learn the
    /// handshake secrets (its masks are blind; see `Tls13KeySched::alloc`).
    Follower,
}
/// Functionality for computing HMAC-SHA-256-based TLS 1.3 key schedule.
pub struct Tls13KeySched {
    // Mode of operation of the underlying primitives.
    mode: Mode,
    // This party's role.
    role: Role,
    // Allocated master secret.
    //
    // Held here until the schedule transitions into the `MasterSecret`
    // state (see `continue_to_app_keys`).
    master_secret: Option<HkdfExtract>,
    // Allocated application secrets.
    //
    // Held here until the handshake hash is set (see `set_handshake_hash`).
    application: Option<ApplicationSecrets>,
    // Current state of the key schedule state machine.
    state: State,
}
impl Tls13KeySched {
    /// Creates a new functionality.
    pub fn new(mode: Mode, role: Role) -> Tls13KeySched {
        Self {
            mode,
            role,
            application: None,
            master_secret: None,
            state: State::Initialized,
        }
    }

    /// Allocates the functionality with the given pre-master secret.
    ///
    /// Allocates the full chain of secrets (handshake, master and
    /// application) up front; only the handshake part is computed first.
    /// Returns an error if not in the `Initialized` state.
    pub fn alloc(&mut self, vm: &mut dyn Vm<Binary>, pms: Array<U8, 32>) -> Result<(), FError> {
        let State::Initialized = self.state.take() else {
            return Err(FError::state("not in initialized state"));
        };
        let mut hs_secrets = HandshakeSecrets::new(self.mode);
        let (cs, ss, derived_secret) = hs_secrets.alloc(vm, pms)?;
        // The handshake traffic secrets are revealed only to the leader:
        // the leader masks them with one-time pads of its choosing, while
        // the follower masks blindly and never learns the pads.
        let (masked_cs, cs_otp, masked_ss, ss_otp) = match self.role {
            Role::Leader => {
                let mut cs_otp = [0u8; 32];
                let mut ss_otp = [0u8; 32];
                rand::rng().fill_bytes(&mut cs_otp);
                rand::rng().fill_bytes(&mut ss_otp);
                let masked_cs = vm.mask_private(cs, cs_otp).map_err(FError::vm)?;
                let masked_ss = vm.mask_private(ss, ss_otp).map_err(FError::vm)?;
                (masked_cs, Some(cs_otp), masked_ss, Some(ss_otp))
            }
            Role::Follower => {
                let masked_cs = vm.mask_blind(cs).map_err(FError::vm)?;
                let masked_ss = vm.mask_blind(ss).map_err(FError::vm)?;
                (masked_cs, None, masked_ss, None)
            }
        };
        // Decode as soon as values are known. The returned futures are
        // dropped; `flush` requests them again when the values are needed.
        std::mem::drop(vm.decode(masked_cs).map_err(FError::vm)?);
        std::mem::drop(vm.decode(masked_ss).map_err(FError::vm)?);
        // Pre-allocate master and application secrets so later phases only
        // need to be flushed, not allocated.
        let hmac_derived = Hmac::alloc(vm, derived_secret, self.mode)?;
        let master_secret = HkdfExtract::alloc(self.mode, vm, [0u8; 32], hmac_derived)?;
        let mut aps = ApplicationSecrets::new(self.mode);
        aps.alloc(vm, master_secret.output())?;
        self.master_secret = Some(master_secret);
        self.application = Some(aps);
        self.state = State::Handshake {
            secrets: hs_secrets,
            masked_cs,
            masked_ss,
            cs_otp,
            ss_otp,
        };
        Ok(())
    }

    /// Whether this functionality needs to be flushed.
    pub fn wants_flush(&self) -> bool {
        match &self.state {
            State::Handshake { secrets, .. } => secrets.wants_flush(),
            State::WantsDecodedKeys { .. } => true,
            State::MasterSecret(ms) => ms.wants_flush(),
            State::Application(app) => app.wants_flush(),
            _ => false,
        }
    }

    /// Flushes the functionality, advancing the state machine as far as the
    /// currently available values allow.
    pub fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), FError> {
        match &mut self.state {
            State::Handshake { secrets, .. } => {
                secrets.flush(vm)?;
                if secrets.is_complete() {
                    match self.state.take() {
                        State::Handshake {
                            masked_cs,
                            masked_ss,
                            cs_otp,
                            ss_otp,
                            ..
                        } => {
                            self.state = State::WantsDecodedKeys {
                                masked_cs,
                                masked_ss,
                                cs_otp,
                                ss_otp,
                            };
                            // Recurse.
                            self.flush(vm)?;
                            return Ok(());
                        }
                        _ => unreachable!(),
                    }
                }
            }
            State::WantsDecodedKeys {
                masked_cs,
                masked_ss,
                cs_otp,
                ss_otp,
            } => {
                // NOTE(review): decoding was already requested in `alloc`;
                // this assumes a repeated `decode` of the same value yields
                // the pending result — confirm against the VM contract.
                let mut masked_cs = vm.decode(*masked_cs).map_err(FError::vm)?;
                let Some(masked_cs) = masked_cs.try_recv().map_err(FError::vm)? else {
                    // Value not available yet; stay in this state.
                    return Ok(());
                };
                let mut masked_ss = vm.decode(*masked_ss).map_err(FError::vm)?;
                let Some(masked_ss) = masked_ss.try_recv().map_err(FError::vm)? else {
                    return Ok(());
                };
                let (ckey, civ, skey, siv) = if self.role == Role::Leader {
                    let cs_otp = cs_otp.expect("leader knows cs otp");
                    let ss_otp = ss_otp.expect("leader knows ss otp");
                    // Unmask the traffic secrets by XOR-ing the pads back in.
                    let mut cs = masked_cs;
                    let mut ss = masked_ss;
                    cs.iter_mut().zip(cs_otp).for_each(|(cs, otp)| {
                        *cs ^= otp;
                    });
                    ss.iter_mut().zip(ss_otp).for_each(|(ss, otp)| {
                        *ss ^= otp;
                    });
                    // The leader derives the handshake keys and IVs locally.
                    let ckey: [u8; 16] = hkdf_expand_label(&cs, b"key", &[], 16)
                        .try_into()
                        .expect("output is 16 bytes");
                    let civ: [u8; 12] = hkdf_expand_label(&cs, b"iv", &[], 12)
                        .try_into()
                        .expect("output is 12 bytes");
                    let skey: [u8; 16] = hkdf_expand_label(&ss, b"key", &[], 16)
                        .try_into()
                        .expect("output is 16 bytes");
                    let siv: [u8; 12] = hkdf_expand_label(&ss, b"iv", &[], 12)
                        .try_into()
                        .expect("output is 12 bytes");
                    (Some(ckey), Some(civ), Some(skey), Some(siv))
                } else {
                    (None, None, None, None)
                };
                self.state = State::KeysDecoded {
                    ckey,
                    civ,
                    skey,
                    siv,
                }
            }
            State::MasterSecret(ms) => {
                ms.flush(vm)?;
                if ms.is_complete() {
                    self.state = State::WantsHandshakeHash;
                }
            }
            State::Application(app) => {
                app.flush(vm)?;
                if app.is_complete() {
                    self.state = State::Complete(app.keys()?);
                }
            }
            _ => (),
        }
        Ok(())
    }

    /// Sets the hash of the ClientHello message.
    ///
    /// Returns an error if not in the `Handshake` state.
    pub fn set_hello_hash(&mut self, hello_hash: [u8; 32]) -> Result<(), FError> {
        match &mut self.state {
            State::Handshake { secrets, .. } => {
                secrets.set_hello_hash(hello_hash)?;
                Ok(())
            }
            _ => Err(FError::state("not in Handshake state")),
        }
    }

    /// Returns handshake keys.
    ///
    /// Only the leader may call this; returns an error for the follower or
    /// when the keys have not been decoded yet.
    pub fn handshake_keys(&mut self) -> Result<HandshakeKeys, FError> {
        if self.role != Role::Leader {
            return Err(FError::state("only leader can access handshake keys"));
        }
        match self.state {
            State::KeysDecoded {
                ckey,
                civ,
                skey,
                siv,
            } => Ok(HandshakeKeys {
                client_write_key: ckey.expect("leader knows key"),
                client_iv: civ.expect("leader knows key"),
                server_write_key: skey.expect("leader knows key"),
                server_iv: siv.expect("leader knows key"),
            }),
            _ => Err(FError::state("not in HandshakeComplete state")),
        }
    }

    /// Continues the key schedule to derive application keys.
    ///
    /// Used after the handshake keys are computed and before the handshake
    /// hash is set.
    pub fn continue_to_app_keys(&mut self) -> Result<(), FError> {
        match self.state {
            State::KeysDecoded { .. } => {
                // Move the pre-allocated master secret into the state machine.
                let ms = mem::take(&mut self.master_secret).expect("master secret is set");
                self.state = State::MasterSecret(ms);
                Ok(())
            }
            _ => Err(FError::state("not in KeysDecoded state")),
        }
    }

    /// Sets the handshake hash.
    ///
    /// Returns an error if not in the `WantsHandshakeHash` state.
    pub fn set_handshake_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), FError> {
        match &mut self.state {
            State::WantsHandshakeHash => {
                // Move the pre-allocated application secrets into the state
                // machine with the hash attached.
                let mut app =
                    mem::take(&mut self.application).expect("application secrets are set");
                app.set_handshake_hash(handshake_hash)?;
                self.state = State::Application(app);
                Ok(())
            }
            _ => Err(FError::state("not in WantsHandshakeHash state")),
        }
    }

    /// Returns VM references to the application keys.
    ///
    /// Returns an error until the key schedule is complete.
    pub fn application_keys(&mut self) -> Result<ApplicationKeys, FError> {
        match self.state {
            State::Complete(keys) => Ok(keys),
            _ => Err(FError::state("not in Complete state")),
        }
    }
}
/// State of the TLS 1.3 key schedule functionality.
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
pub(crate) enum State {
    /// The initial state, prior to allocation.
    Initialized,
    /// The state in which some of the handshake secrets are computed in MPC.
    Handshake {
        secrets: HandshakeSecrets,
        // Client handshake traffic secret masked with a one-time pad.
        masked_cs: Array<U8, 32>,
        // Server handshake traffic secret masked with a one-time pad.
        masked_ss: Array<U8, 32>,
        // The pads; `Some` only for the leader.
        cs_otp: Option<[u8; 32]>,
        ss_otp: Option<[u8; 32]>,
    },
    /// The state after all handshake-related MPC operations were completed
    /// and the keys need to be decoded.
    WantsDecodedKeys {
        masked_cs: Array<U8, 32>,
        masked_ss: Array<U8, 32>,
        cs_otp: Option<[u8; 32]>,
        ss_otp: Option<[u8; 32]>,
    },
    /// The state after the handshake keys were decoded and made known to the
    /// leader.
    KeysDecoded {
        // Cleartext key material; `Some` only for the leader.
        ckey: Option<[u8; 16]>,
        civ: Option<[u8; 12]>,
        skey: Option<[u8; 16]>,
        siv: Option<[u8; 12]>,
    },
    /// The state in which the master secret is computed.
    ///
    /// Computing master secret before handshake hash is set can potentially
    /// improve overall performance.
    MasterSecret(HkdfExtract),
    /// The state in which the master secret has been computed and the
    /// handshake hash is expected to be set.
    WantsHandshakeHash,
    /// The state in which the application secrets are derived.
    Application(ApplicationSecrets),
    /// The key schedule is complete; holds the application keys.
    Complete(ApplicationKeys),
    /// The sink state entered when a state was taken out (see `take`) or an
    /// operation failed.
    Error,
}
impl State {
pub(crate) fn take(&mut self) -> State {
std::mem::replace(self, State::Error)
}
}
/// Handshake keys computed by the key schedule.
///
/// These are cleartext values; only the leader obtains them (see
/// `Tls13KeySched::handshake_keys`).
#[derive(Debug, Clone, Copy)]
pub struct HandshakeKeys {
    /// Client write key.
    pub client_write_key: [u8; 16],
    /// Server write key.
    pub server_write_key: [u8; 16],
    /// Client IV.
    pub client_iv: [u8; 12],
    /// Server IV.
    pub server_iv: [u8; 12],
}
/// Application keys computed by the key schedule.
///
/// These are VM references, not cleartext values; decode them via the VM to
/// obtain the key material.
#[derive(Debug, Clone, Copy)]
pub struct ApplicationKeys {
    /// Client write key.
    pub client_write_key: Array<U8, 16>,
    /// Server write key.
    pub server_write_key: Array<U8, 16>,
    /// Client IV.
    pub client_iv: Array<U8, 12>,
    /// Server IV.
    pub server_iv: Array<U8, 12>,
}
#[cfg(test)]
mod tests {
    use crate::{
        test_utils::mock_vm,
        tls13::{Role, Tls13KeySched},
        ApplicationKeys, HandshakeKeys, Mode,
    };
    use mpz_common::{context::test_st_context, Context};
    use mpz_vm_core::{
        memory::{
            binary::{Binary, U8},
            Array, MemoryExt, ViewExt,
        },
        Vm,
    };
    use rstest::*;

    /// Runs the full TLS 1.3 key schedule for both parties in both modes and
    /// checks the handshake and application keys against the RFC draft
    /// reference vectors (see `test_fixtures`).
    #[rstest]
    #[case::normal(Mode::Normal)]
    #[case::reduced(Mode::Reduced)]
    #[tokio::test]
    async fn test_tls13_key_sched(#[case] mode: Mode) {
        let (
            pms,
            hello_hash,
            handshake_hash,
            ckey_hs,
            civ_hs,
            skey_hs,
            siv_hs,
            ckey_app,
            civ_app,
            skey_app,
            siv_app,
        ) = test_fixtures();
        let (mut ctx_a, mut ctx_b) = test_st_context(8);
        let (mut leader, mut follower) = mock_vm();
        // PMS is a private output from previous MPC computations not known
        // to either party. For simplicity, it is marked public in this test.
        let pms: [u8; 32] = pms.try_into().unwrap();

        // Commits the PMS into the VM and allocates a key schedule on it.
        fn setup_ks(
            vm: &mut (dyn Vm<Binary> + Send),
            pms: [u8; 32],
            mode: Mode,
            role: Role,
        ) -> Tls13KeySched {
            let secret: Array<U8, 32> = vm.alloc().unwrap();
            vm.mark_public(secret).unwrap();
            vm.assign(secret, pms).unwrap();
            vm.commit(secret).unwrap();
            let mut ks = Tls13KeySched::new(mode, role);
            ks.alloc(vm, secret).unwrap();
            ks
        }
        let mut leader_ks = setup_ks(&mut leader, pms, mode, Role::Leader);
        let mut follower_ks = setup_ks(&mut follower, pms, mode, Role::Follower);

        // Drives one party's key schedule through all phases, returning the
        // handshake keys (leader only) and the decoded application keys.
        async fn run_ks(
            vm: &mut (dyn Vm<Binary> + Send),
            ks: &mut Tls13KeySched,
            ctx: &mut Context,
            role: Role,
            mode: Mode,
            hello_hash: Vec<u8>,
            handshake_hash: Vec<u8>,
        ) -> Result<
            (
                Option<HandshakeKeys>,
                ([u8; 16], [u8; 12], [u8; 16], [u8; 12]),
            ),
            Box<dyn std::error::Error>,
        > {
            let res = async move {
                vm.execute_all(ctx).await.unwrap();
                flush_execute(ks, vm, ctx, false).await;
                ks.set_hello_hash(hello_hash.try_into().unwrap()).unwrap();
                // One extra flush to process decoded handshake secrets.
                flush_execute(ks, vm, ctx, true).await;
                let hs_keys = if role == Role::Leader {
                    Some(ks.handshake_keys().unwrap())
                } else {
                    None
                };
                ks.continue_to_app_keys().unwrap();
                flush_execute(ks, vm, ctx, false).await;
                ks.set_handshake_hash(handshake_hash.try_into().unwrap())
                    .unwrap();
                if mode == Mode::Reduced {
                    // One extra flush to process decoded inner_partial.
                    flush_execute(ks, vm, ctx, true).await;
                } else {
                    flush_execute(ks, vm, ctx, false).await;
                }
                let ApplicationKeys {
                    client_write_key,
                    client_iv,
                    server_write_key,
                    server_iv,
                } = ks.application_keys().unwrap();
                // Decode the application keys for comparison.
                let mut ckey_fut = vm.decode(client_write_key).unwrap();
                let mut civ_fut = vm.decode(client_iv).unwrap();
                let mut skey_fut = vm.decode(server_write_key).unwrap();
                let mut siv_fut = vm.decode(server_iv).unwrap();
                vm.execute_all(ctx).await.unwrap();
                let ckey = ckey_fut.try_recv().unwrap().unwrap();
                let civ = civ_fut.try_recv().unwrap().unwrap();
                let skey = skey_fut.try_recv().unwrap().unwrap();
                let siv = siv_fut.try_recv().unwrap().unwrap();
                (hs_keys, (ckey, civ, skey, siv))
            }
            .await;
            Ok(res)
        }
        let (out_leader, out_follower) = tokio::try_join!(
            run_ks(
                &mut leader,
                &mut leader_ks,
                &mut ctx_a,
                Role::Leader,
                mode,
                hello_hash.clone(),
                handshake_hash.clone()
            ),
            run_ks(
                &mut follower,
                &mut follower_ks,
                &mut ctx_b,
                Role::Follower,
                mode,
                hello_hash,
                handshake_hash
            )
        )
        .unwrap();
        // Handshake keys must match the reference vectors.
        let hs_keys_leader = out_leader.0.unwrap();
        assert_eq!(
            (
                hs_keys_leader.client_write_key.to_vec(),
                hs_keys_leader.client_iv.to_vec(),
                hs_keys_leader.server_write_key.to_vec(),
                hs_keys_leader.server_iv.to_vec()
            ),
            (ckey_hs, civ_hs, skey_hs, siv_hs)
        );
        // Application keys must agree between parties and match the vectors.
        let app_keys_leader = out_leader.1;
        let app_keys_follower = out_follower.1;
        assert_eq!(app_keys_leader, app_keys_follower);
        assert_eq!(
            app_keys_leader,
            (
                ckey_app.try_into().unwrap(),
                civ_app.try_into().unwrap(),
                skey_app.try_into().unwrap(),
                siv_app.try_into().unwrap()
            )
        );
    }

    /// Flushes the key schedule once, executes the VM, and optionally
    /// performs one more flush, asserting the schedule no longer wants one.
    async fn flush_execute(
        ks: &mut Tls13KeySched,
        vm: &mut (dyn Vm<Binary> + Send),
        ctx: &mut Context,
        // Whether after executing the VM, one extra flush is required.
        extra_flush: bool,
    ) {
        assert!(ks.wants_flush());
        ks.flush(vm).unwrap();
        vm.execute_all(ctx).await.unwrap();
        if extra_flush {
            assert!(ks.wants_flush());
            ks.flush(vm).unwrap();
        }
        assert!(!ks.wants_flush())
    }

    // Reference values from https://datatracker.ietf.org/doc/html/draft-ietf-tls-tls13-vectors-06
    #[allow(clippy::type_complexity)]
    fn test_fixtures() -> (
        Vec<u8>,
        Vec<u8>,
        Vec<u8>,
        Vec<u8>,
        Vec<u8>,
        Vec<u8>,
        Vec<u8>,
        Vec<u8>,
        Vec<u8>,
        Vec<u8>,
        Vec<u8>,
    ) {
        (
            // PMS
            from_hex_str("81 51 d1 46 4c 1b 55 53 36 23 b9 c2 24 6a 6a 0e 6e 7e 18 50 63 e1 4a fd af f0 b6 e1 c6 1a 86 42"),
            // HELLO_HASH
            from_hex_str("c6 c9 18 ad 2f 41 99 d5 59 8e af 01 16 cb 7a 5c 2c 14 cb 54 78 12 18 88 8d b7 03 0d d5 0d 5e 6d"),
            // HANDSHAKE_HASH
            from_hex_str("f8 c1 9e 8c 77 c0 38 79 bb c8 eb 6d 56 e0 0d d5 d8 6e f5 59 27 ee fc 08 e1 b0 02 b6 ec e0 5d bf"),
            // CKEY_HS
            from_hex_str("26 79 a4 3e 1d 76 78 40 34 ea 17 97 d5 ad 26 49"),
            // CIV_HS
            from_hex_str("54 82 40 52 90 dd 0d 2f 81 c0 d9 42"),
            // SKEY_HS
            from_hex_str("c6 6c b1 ae c5 19 df 44 c9 1e 10 99 55 11 ac 8b"),
            // SIV_HS
            from_hex_str("f7 f6 88 4c 49 81 71 6c 2d 0d 29 a4"),
            // CKEY_APP
            from_hex_str("88 b9 6a d6 86 c8 4b e5 5a ce 18 a5 9c ce 5c 87"),
            // CIV_APP
            from_hex_str("b9 9d c5 8c d5 ff 5a b0 82 fd ad 19"),
            // SKEY_APP
            from_hex_str("a6 88 eb b5 ac 82 6d 6f 42 d4 5c 0c c4 4b 9b 7d"),
            // SIV_APP
            from_hex_str("c1 ca d4 42 5a 43 8b 5d e7 14 83 0a"),
        )
    }

    /// Decodes a space-separated hex string into bytes.
    fn from_hex_str(s: &str) -> Vec<u8> {
        hex::decode(s.split_whitespace().collect::<String>()).unwrap()
    }
}

View File

@@ -0,0 +1,233 @@
use crate::{
hmac::Hmac,
kdf::expand::{HkdfExpand, EMPTY_CTX},
ApplicationKeys, FError, Mode,
};
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Vector,
},
Vm,
};
/// Functionality for computing application secrets of TLS 1.3 key schedule.
#[derive(Debug)]
pub(crate) struct ApplicationSecrets {
    // Mode of operation of the underlying primitives.
    mode: Mode,
    // Current state.
    state: State,
    // Each of the following is `Some` once `alloc` has been called.
    client_secret: Option<HkdfExpand>,
    server_secret: Option<HkdfExpand>,
    client_application_key: Option<HkdfExpand>,
    client_application_iv: Option<HkdfExpand>,
    server_application_key: Option<HkdfExpand>,
    server_application_iv: Option<HkdfExpand>,
}
impl ApplicationSecrets {
    /// Creates a new functionality.
    pub(crate) fn new(mode: Mode) -> ApplicationSecrets {
        Self {
            mode,
            state: State::Initialized,
            client_secret: None,
            server_secret: None,
            client_application_key: None,
            client_application_iv: None,
            server_application_key: None,
            server_application_iv: None,
        }
    }

    /// Allocates the functionality with the given `master_secret`.
    ///
    /// Allocates the `c ap traffic`/`s ap traffic` secret expansions and the
    /// key/IV expansions derived from them. Returns an error if not in the
    /// `Initialized` state.
    pub(crate) fn alloc(
        &mut self,
        vm: &mut dyn Vm<Binary>,
        master_secret: Vector<U8>,
    ) -> Result<(), FError> {
        let State::Initialized = self.state.take() else {
            return Err(FError::state("not in Initialized state"));
        };
        let mode = self.mode;
        // Two HMAC instances over the master secret: one per traffic secret.
        let hmac_ms1 = Hmac::alloc(vm, master_secret, mode)?;
        let hmac_ms2 = Hmac::from_other(vm, &hmac_ms1)?;
        // The context (transcript hash) is supplied later via
        // `set_handshake_hash`.
        let client_secret = HkdfExpand::alloc(mode, vm, hmac_ms1, b"c ap traffic", None, 32, 32)?;
        let server_secret = HkdfExpand::alloc(mode, vm, hmac_ms2, b"s ap traffic", None, 32, 32)?;
        let hmac_cs1 = Hmac::alloc(vm, client_secret.output(), mode)?;
        let hmac_cs2 = Hmac::from_other(vm, &hmac_cs1)?;
        let hmac_ss1 = Hmac::alloc(vm, server_secret.output(), mode)?;
        let hmac_ss2 = Hmac::from_other(vm, &hmac_ss1)?;
        // Key and IV expansions use the empty context.
        let client_application_key =
            HkdfExpand::alloc(mode, vm, hmac_cs1, b"key", Some(&EMPTY_CTX), 0, 16)?;
        let client_application_iv =
            HkdfExpand::alloc(mode, vm, hmac_cs2, b"iv", Some(&EMPTY_CTX), 0, 12)?;
        let server_application_key =
            HkdfExpand::alloc(mode, vm, hmac_ss1, b"key", Some(&EMPTY_CTX), 0, 16)?;
        let server_application_iv =
            HkdfExpand::alloc(mode, vm, hmac_ss2, b"iv", Some(&EMPTY_CTX), 0, 12)?;
        self.state = State::WantsHandshakeHash;
        self.client_secret = Some(client_secret);
        self.server_secret = Some(server_secret);
        self.client_application_key = Some(client_application_key);
        self.client_application_iv = Some(client_application_iv);
        self.server_application_key = Some(server_application_key);
        self.server_application_iv = Some(server_application_iv);
        Ok(())
    }

    /// Whether this functionality needs to be flushed.
    ///
    /// Panics if called before `alloc`.
    pub(crate) fn wants_flush(&self) -> bool {
        let client_secret = self.client_secret.as_ref().expect("functionality was set");
        let server_secret = self.server_secret.as_ref().expect("functionality was set");
        let client_application_key = self
            .client_application_key
            .as_ref()
            .expect("functionality was set");
        let client_application_iv = self
            .client_application_iv
            .as_ref()
            .expect("functionality was set");
        let server_application_key = self
            .server_application_key
            .as_ref()
            .expect("functionality was set");
        let server_application_iv = self
            .server_application_iv
            .as_ref()
            .expect("functionality was set");
        // A pending handshake hash also requires a flush.
        let state_wants_flush = matches!(&self.state, State::HandshakeHashSet(..));
        state_wants_flush
            || client_secret.wants_flush()
            || server_secret.wants_flush()
            || client_application_key.wants_flush()
            || client_application_iv.wants_flush()
            || server_application_key.wants_flush()
            || server_application_iv.wants_flush()
    }

    /// Flushes the functionality.
    ///
    /// Once the handshake hash is set, feeds it as context into the traffic
    /// secret expansions (once) and transitions to `Complete` when every
    /// expansion has finished. Panics if called before `alloc`.
    pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), FError> {
        let client_secret = self.client_secret.as_mut().expect("functionality was set");
        let server_secret = self.server_secret.as_mut().expect("functionality was set");
        let client_application_key = self
            .client_application_key
            .as_mut()
            .expect("functionality was set");
        let client_application_iv = self
            .client_application_iv
            .as_mut()
            .expect("functionality was set");
        let server_application_key = self
            .server_application_key
            .as_mut()
            .expect("functionality was set");
        let server_application_iv = self
            .server_application_iv
            .as_mut()
            .expect("functionality was set");
        client_secret.flush(vm)?;
        server_secret.flush(vm)?;
        client_application_key.flush(vm)?;
        client_application_iv.flush(vm)?;
        server_application_key.flush(vm)?;
        server_application_iv.flush(vm)?;
        if let State::HandshakeHashSet(hash) = &self.state {
            // Set the context only once, guarded by `is_ctx_set`.
            if !client_secret.is_ctx_set() {
                client_secret.set_ctx(hash)?;
                client_secret.flush(vm)?;
            }
            if !server_secret.is_ctx_set() {
                server_secret.set_ctx(hash)?;
                server_secret.flush(vm)?;
            }
            if client_application_iv.is_complete()
                && client_application_key.is_complete()
                && client_secret.is_complete()
                && server_application_iv.is_complete()
                && server_application_key.is_complete()
                && server_secret.is_complete()
            {
                self.state = State::Complete(ApplicationKeys {
                    client_write_key: client_application_key
                        .output()
                        .try_into()
                        .expect("key length is 16 bytes"),
                    client_iv: client_application_iv
                        .output()
                        .try_into()
                        .expect("iv length is 12 bytes"),
                    server_write_key: server_application_key
                        .output()
                        .try_into()
                        .expect("key length is 16 bytes"),
                    server_iv: server_application_iv
                        .output()
                        .try_into()
                        .expect("iv length is 12 bytes"),
                });
            }
        }
        Ok(())
    }

    /// Sets the handshake hash.
    ///
    /// Returns an error if not in the `WantsHandshakeHash` state.
    pub(crate) fn set_handshake_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), FError> {
        match &mut self.state {
            State::WantsHandshakeHash => {
                self.state = State::HandshakeHashSet(handshake_hash);
                Ok(())
            }
            _ => Err(FError::state("not in WantsHandshakeHash state")),
        }
    }

    /// Returns the application keys.
    ///
    /// Returns an error until the functionality is complete.
    pub(crate) fn keys(&mut self) -> Result<ApplicationKeys, FError> {
        match self.state {
            State::Complete(keys) => Ok(keys),
            _ => Err(FError::state("not in Complete state")),
        }
    }

    /// Whether this functionality is complete.
    pub(crate) fn is_complete(&self) -> bool {
        matches!(self.state, State::Complete { .. })
    }
}
/// State of the application secrets functionality.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub(crate) enum State {
    /// The initial state, prior to allocation.
    Initialized,
    /// Wants handshake hash to be set.
    WantsHandshakeHash,
    /// Handshake hash has been set.
    HandshakeHashSet([u8; 32]),
    /// All application keys have been computed.
    Complete(ApplicationKeys),
    /// The sink state entered when a state was taken out (see `take`) or an
    /// operation failed.
    Error,
}
impl State {
pub(crate) fn take(&mut self) -> State {
std::mem::replace(self, State::Error)
}
}

View File

@@ -0,0 +1,206 @@
use crate::{
hmac::{normal::HmacNormal, Hmac},
kdf::{expand::HkdfExpand, extract::HkdfExtractPrivIkm},
FError, Mode,
};
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array, Vector,
},
Vm,
};
// INNER_PARTIAL and OUTER_PARTIAL were computed using the code below:
//
// // A deterministic derived secret for handshake for SHA-256 ciphersuites.
// // see https://datatracker.ietf.org/doc/html/draft-ietf-tls-tls13-vectors-06
// let derived_secret: Vec<u8> = vec![
// 0x6f, 0x26, 0x15, 0xa1, 0x08, 0xc7, 0x02, 0xc5, 0x67, 0x8f, 0x54,
// 0xfc, 0x9d, 0xba, 0xb6, 0x97, 0x16, 0xc0, 0x76, 0x18, 0x9c, 0x48,
// 0x25, 0x0c, 0xeb, 0xea, 0xc3, 0x57, 0x6c, 0x36, 0x11, 0xba];
//
// let inner_partial = clear::compute_inner_partial(derived_secret.clone());
// let outer_partial = clear::compute_outer_partial(derived_secret);
/// A deterministic inner partial hash state of the derived secret for
/// handshake for SHA-256 ciphersuites.
const INNER_PARTIAL: [u32; 8] = [
    2335507740, 2200227439, 3546272834, 83913483, 301355998, 2266431524, 1402092146, 439257589,
];
/// A deterministic outer partial hash state of the derived secret for
/// handshake for SHA-256 ciphersuites.
// (The doc previously said "inner" — fixed copy-paste; this is the outer
// state, computed via `clear::compute_outer_partial`, see the note above.)
const OUTER_PARTIAL: [u32; 8] = [
    582556975, 2818161237, 3127925320, 2797531207, 4122647441, 3290806166, 3682628262, 2419579842,
];
/// The digest of SHA256("").
const EMPTY_HASH: [u8; 32] = [
    0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24,
    0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55,
];
/// Functionality for computing handshake secrets of TLS 1.3 key schedule.
#[derive(Debug)]
pub(crate) struct HandshakeSecrets {
    // Mode of operation of the underlying primitives.
    mode: Mode,
    // Current state.
    state: State,
    // Each of the following is `Some` once `alloc` has been called.
    handshake_secret: Option<HkdfExtractPrivIkm>,
    client_secret: Option<HkdfExpand>,
    server_secret: Option<HkdfExpand>,
    derived_secret: Option<HkdfExpand>,
}
impl HandshakeSecrets {
/// Creates a new functionality.
pub(crate) fn new(mode: Mode) -> HandshakeSecrets {
Self {
mode,
state: State::Initialized,
handshake_secret: None,
client_secret: None,
server_secret: None,
derived_secret: None,
}
}
/// Allocates the functionality with the given pre-master secret.
///
/// Returns client_handshake_traffic_secret,
/// server_handshake_traffic_secret, and derived_secret for master_secret.
#[allow(clippy::type_complexity)]
pub(crate) fn alloc(
    &mut self,
    vm: &mut dyn Vm<Binary>,
    pms: Array<U8, 32>,
) -> Result<(Array<U8, 32>, Array<U8, 32>, Vector<U8>), FError> {
    let State::Initialized = self.state.take() else {
        return Err(FError::state("not in Initialized state"));
    };
    let mode = self.mode;
    // HMAC keyed with the fixed derived secret via its precomputed partial
    // hash states (see INNER_PARTIAL/OUTER_PARTIAL above).
    let hmac = HmacNormal::alloc_with_state(vm, INNER_PARTIAL, OUTER_PARTIAL)?;
    let handshake_secret = HkdfExtractPrivIkm::alloc(vm, pms, hmac)?;
    // Three HMAC instances over the handshake secret: one per expansion.
    let hmac_hs1 = Hmac::alloc(vm, handshake_secret.output(), mode)?;
    let hmac_hs2 = Hmac::from_other(vm, &hmac_hs1)?;
    let hmac_hs3 = Hmac::from_other(vm, &hmac_hs1)?;
    // The context (hello hash) for the traffic secrets is supplied later
    // via `set_hello_hash`.
    let client_secret = HkdfExpand::alloc(mode, vm, hmac_hs1, b"c hs traffic", None, 32, 32)?;
    let server_secret = HkdfExpand::alloc(mode, vm, hmac_hs2, b"s hs traffic", None, 32, 32)?;
    // Optimization: by computing now the derived_secret for
    // master_secret in parallel with cs and ss, we save communication
    // rounds when we are in the reduced mode.
    let derived_secret =
        HkdfExpand::alloc(mode, vm, hmac_hs3, b"derived", Some(&EMPTY_HASH), 32, 32)?;
    let cs_out: Array<U8, 32> = client_secret
        .output()
        .try_into()
        .expect("client secret is 32 bytes");
    let ss_out = server_secret
        .output()
        .try_into()
        .expect("server secret is 32 bytes");
    let derived_output = derived_secret.output();
    self.handshake_secret = Some(handshake_secret);
    self.client_secret = Some(client_secret);
    self.server_secret = Some(server_secret);
    self.derived_secret = Some(derived_secret);
    self.state = State::WantsHelloHash;
    Ok((cs_out, ss_out, derived_output))
}
/// Whether this functionality needs to be flushed.
pub(crate) fn wants_flush(&self) -> bool {
let client_secret = self.client_secret.as_ref().expect("functionality was set");
let server_secret = self.server_secret.as_ref().expect("functionality was set");
let derived_secret = self.derived_secret.as_ref().expect("functionality was set");
let handshake_secret = self
.handshake_secret
.as_ref()
.expect("functionality was set");
let state_wants_flush = matches!(&self.state, State::HelloHashSet(..));
state_wants_flush
|| client_secret.wants_flush()
|| server_secret.wants_flush()
|| derived_secret.wants_flush()
|| handshake_secret.wants_flush()
}
/// Flushes the functionality.
pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), FError> {
let client_secret = self.client_secret.as_mut().expect("functionality was set");
let server_secret = self.server_secret.as_mut().expect("functionality was set");
let derived_secret = self.derived_secret.as_mut().expect("functionality was set");
let handshake_secret = self
.handshake_secret
.as_mut()
.expect("functionality was set");
client_secret.flush(vm)?;
derived_secret.flush(vm)?;
handshake_secret.flush();
server_secret.flush(vm)?;
if let State::HelloHashSet(hash) = &mut self.state {
client_secret.set_ctx(hash)?;
client_secret.flush(vm)?;
server_secret.set_ctx(hash)?;
server_secret.flush(vm)?;
if handshake_secret.is_complete()
&& client_secret.is_complete()
&& server_secret.is_complete()
&& derived_secret.is_complete()
{
self.state = State::Complete;
}
}
Ok(())
}
/// Sets the hash of the ClientHello message.
pub(crate) fn set_hello_hash(&mut self, hello_hash: [u8; 32]) -> Result<(), FError> {
match &mut self.state {
State::WantsHelloHash => {
self.state = State::HelloHashSet(hello_hash);
Ok(())
}
_ => Err(FError::state("not in WantsHelloHash state")),
}
}
/// Whether this functionality is complete.
pub(crate) fn is_complete(&self) -> bool {
matches!(self.state, State::Complete)
}
}
/// Progress of the [`HandshakeSecrets`] state machine.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub(crate) enum State {
    /// Created, circuits not yet allocated.
    Initialized,
    /// Allocated; waiting for the ClientHello transcript hash.
    WantsHelloHash,
    /// Hash received; secrets still being flushed to completion.
    HelloHashSet([u8; 32]),
    /// All secrets computed.
    Complete,
    /// A fallible transition was interrupted; the machine is poisoned.
    Error,
}

impl State {
    /// Moves the current state out, leaving [`State::Error`] in its place.
    ///
    /// Callers that complete their transition successfully are expected to
    /// overwrite the poisoned `Error` value with the next state.
    pub(crate) fn take(&mut self) -> State {
        let mut taken = State::Error;
        std::mem::swap(self, &mut taken);
        taken
    }
}

View File

@@ -5,7 +5,7 @@ description = "Implementation of the 3-party key-exchange protocol"
keywords = ["tls", "mpc", "2pc", "pms", "key-exchange"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.14-pre"
version = "0.1.0-alpha.13-pre"
edition = "2021"
[lints]
@@ -40,7 +40,6 @@ tokio = { workspace = true, features = ["sync"] }
[dev-dependencies]
mpz-ot = { workspace = true, features = ["ideal"] }
mpz-garble = { workspace = true }
mpz-ideal-vm = { workspace = true }
rand_core = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }

View File

@@ -459,7 +459,9 @@ mod tests {
use mpz_common::context::test_st_context;
use mpz_core::Block;
use mpz_fields::UniformRand;
use mpz_ideal_vm::IdealVm;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_memory_core::correlated::Delta;
use mpz_ot::ideal::cot::{ideal_cot, IdealCOTReceiver, IdealCOTSender};
use mpz_share_conversion::ideal::{
ideal_share_convert, IdealShareConvertReceiver, IdealShareConvertSender,
};
@@ -482,8 +484,7 @@ mod tests {
async fn test_key_exchange() {
let mut rng = StdRng::seed_from_u64(0).compat();
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let mut gen = IdealVm::new();
let mut ev = IdealVm::new();
let (mut gen, mut ev) = mock_vm();
let leader_private_key = SecretKey::random(&mut rng);
let follower_private_key = SecretKey::random(&mut rng);
@@ -624,8 +625,7 @@ mod tests {
async fn test_malicious_key_exchange(#[case] malicious: Malicious) {
let mut rng = StdRng::seed_from_u64(0);
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let mut gen = IdealVm::new();
let mut ev = IdealVm::new();
let (mut gen, mut ev) = mock_vm();
let leader_private_key = SecretKey::random(&mut rng.compat_by_ref());
let follower_private_key = SecretKey::random(&mut rng.compat_by_ref());
@@ -704,8 +704,7 @@ mod tests {
#[tokio::test]
async fn test_circuit() {
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let gen = IdealVm::new();
let ev = IdealVm::new();
let (gen, ev) = mock_vm();
let share_a0_bytes = [5_u8; 32];
let share_a1_bytes = [2_u8; 32];
@@ -835,4 +834,16 @@ mod tests {
(leader, follower)
}
fn mock_vm() -> (Garbler<IdealCOTSender>, Evaluator<IdealCOTReceiver>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
let gen = Garbler::new(cot_send, [0u8; 16], delta);
let ev = Evaluator::new(cot_recv);
(gen, ev)
}
}

View File

@@ -8,7 +8,7 @@
//! with the server alone and forward all messages from and to the follower.
//!
//! A detailed description of this protocol can be found in our documentation
//! <https://tlsnotary.org/docs/mpc/key_exchange>.
//! <https://docs.tlsnotary.org/protocol/notarization/key_exchange.html>.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]

View File

@@ -26,7 +26,8 @@ pub fn create_mock_key_exchange_pair() -> (MockKeyExchange, MockKeyExchange) {
#[cfg(test)]
mod tests {
use mpz_ideal_vm::IdealVm;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_ot::ideal::cot::{IdealCOTReceiver, IdealCOTSender};
use super::*;
use crate::KeyExchange;
@@ -39,12 +40,12 @@ mod tests {
is_key_exchange::<
MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>,
IdealVm,
Garbler<IdealCOTSender>,
>(leader);
is_key_exchange::<
MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>,
IdealVm,
Evaluator<IdealCOTReceiver>,
>(follower);
}
}

View File

@@ -4,7 +4,7 @@
//! protocol has semi-honest security.
//!
//! The protocol is described in
//! <https://tlsnotary.org/docs/mpc/key_exchange>
//! <https://docs.tlsnotary.org/protocol/notarization/key_exchange.html>
use crate::{KeyExchangeError, Role};
use mpz_common::{Context, Flush};

View File

@@ -5,7 +5,7 @@ description = "Core types for TLSNotary"
keywords = ["tls", "mpc", "2pc", "types"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.14-pre"
version = "0.1.0-alpha.13-pre"
edition = "2021"
[lints]
@@ -13,7 +13,6 @@ workspace = true
[features]
default = []
mozilla-certs = ["dep:webpki-root-certs", "dep:webpki-roots"]
fixtures = [
"dep:hex",
"dep:tlsn-data-fixtures",
@@ -45,8 +44,7 @@ sha2 = { workspace = true }
thiserror = { workspace = true }
tiny-keccak = { workspace = true, features = ["keccak"] }
web-time = { workspace = true }
webpki-roots = { workspace = true, optional = true }
webpki-root-certs = { workspace = true, optional = true }
webpki-roots = { workspace = true }
rustls-webpki = { workspace = true, features = ["ring"] }
rustls-pki-types = { workspace = true }
itybity = { workspace = true }
@@ -59,7 +57,5 @@ generic-array = { workspace = true }
bincode = { workspace = true }
hex = { workspace = true }
rstest = { workspace = true }
tlsn-core = { workspace = true, features = ["fixtures"] }
tlsn-attestation = { workspace = true, features = ["fixtures"] }
tlsn-data-fixtures = { workspace = true }
webpki-root-certs = { workspace = true }

View File

@@ -1,7 +0,0 @@
//! Configuration types.
pub mod prove;
pub mod prover;
pub mod tls;
pub mod tls_commit;
pub mod verifier;

View File

@@ -1,189 +0,0 @@
//! Proving configuration.
use rangeset::set::{RangeSet, ToRangeSet};
use serde::{Deserialize, Serialize};
use crate::transcript::{Direction, Transcript, TranscriptCommitConfig, TranscriptCommitRequest};
/// Configuration to prove information to the verifier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProveConfig {
server_identity: bool,
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
transcript_commit: Option<TranscriptCommitConfig>,
}
impl ProveConfig {
/// Creates a new builder.
pub fn builder(transcript: &Transcript) -> ProveConfigBuilder<'_> {
ProveConfigBuilder::new(transcript)
}
/// Returns `true` if the server identity is to be proven.
pub fn server_identity(&self) -> bool {
self.server_identity
}
/// Returns the sent and received ranges of the transcript to be revealed,
/// respectively.
pub fn reveal(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
self.reveal.as_ref()
}
/// Returns the transcript commitment configuration.
pub fn transcript_commit(&self) -> Option<&TranscriptCommitConfig> {
self.transcript_commit.as_ref()
}
/// Returns a request.
pub fn to_request(&self) -> ProveRequest {
ProveRequest {
server_identity: self.server_identity,
reveal: self.reveal.clone(),
transcript_commit: self
.transcript_commit
.clone()
.map(|config| config.to_request()),
}
}
}
/// Builder for [`ProveConfig`].
#[derive(Debug)]
pub struct ProveConfigBuilder<'a> {
transcript: &'a Transcript,
server_identity: bool,
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
transcript_commit: Option<TranscriptCommitConfig>,
}
impl<'a> ProveConfigBuilder<'a> {
/// Creates a new builder.
pub fn new(transcript: &'a Transcript) -> Self {
Self {
transcript,
server_identity: false,
reveal: None,
transcript_commit: None,
}
}
/// Proves the server identity.
pub fn server_identity(&mut self) -> &mut Self {
self.server_identity = true;
self
}
/// Configures transcript commitments.
pub fn transcript_commit(&mut self, transcript_commit: TranscriptCommitConfig) -> &mut Self {
self.transcript_commit = Some(transcript_commit);
self
}
/// Reveals the given ranges of the transcript.
pub fn reveal(
&mut self,
direction: Direction,
ranges: &dyn ToRangeSet<usize>,
) -> Result<&mut Self, ProveConfigError> {
let idx = ranges.to_range_set();
if idx.end().unwrap_or(0) > self.transcript.len_of_direction(direction) {
return Err(ProveConfigError(ErrorRepr::IndexOutOfBounds {
direction,
actual: idx.end().unwrap_or(0),
len: self.transcript.len_of_direction(direction),
}));
}
let (sent, recv) = self.reveal.get_or_insert_default();
match direction {
Direction::Sent => sent.union_mut(&idx),
Direction::Received => recv.union_mut(&idx),
}
Ok(self)
}
/// Reveals the given ranges of the sent data transcript.
pub fn reveal_sent(
&mut self,
ranges: &dyn ToRangeSet<usize>,
) -> Result<&mut Self, ProveConfigError> {
self.reveal(Direction::Sent, ranges)
}
/// Reveals all of the sent data transcript.
pub fn reveal_sent_all(&mut self) -> Result<&mut Self, ProveConfigError> {
let len = self.transcript.len_of_direction(Direction::Sent);
let (sent, _) = self.reveal.get_or_insert_default();
sent.union_mut(&(0..len));
Ok(self)
}
/// Reveals the given ranges of the received data transcript.
pub fn reveal_recv(
&mut self,
ranges: &dyn ToRangeSet<usize>,
) -> Result<&mut Self, ProveConfigError> {
self.reveal(Direction::Received, ranges)
}
/// Reveals all of the received data transcript.
pub fn reveal_recv_all(&mut self) -> Result<&mut Self, ProveConfigError> {
let len = self.transcript.len_of_direction(Direction::Received);
let (_, recv) = self.reveal.get_or_insert_default();
recv.union_mut(&(0..len));
Ok(self)
}
/// Builds the configuration.
pub fn build(self) -> Result<ProveConfig, ProveConfigError> {
Ok(ProveConfig {
server_identity: self.server_identity,
reveal: self.reveal,
transcript_commit: self.transcript_commit,
})
}
}
/// Request to prove statements about the connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProveRequest {
server_identity: bool,
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
transcript_commit: Option<TranscriptCommitRequest>,
}
impl ProveRequest {
/// Returns `true` if the server identity is to be proven.
pub fn server_identity(&self) -> bool {
self.server_identity
}
/// Returns the sent and received ranges of the transcript to be revealed,
/// respectively.
pub fn reveal(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
self.reveal.as_ref()
}
/// Returns the transcript commitment configuration.
pub fn transcript_commit(&self) -> Option<&TranscriptCommitRequest> {
self.transcript_commit.as_ref()
}
}
/// Error for [`ProveConfig`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ProveConfigError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("range is out of bounds of the transcript ({direction}): {actual} > {len}")]
IndexOutOfBounds {
direction: Direction,
actual: usize,
len: usize,
},
}

View File

@@ -1,33 +0,0 @@
//! Prover configuration.
use serde::{Deserialize, Serialize};
/// Prover configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProverConfig {}
impl ProverConfig {
/// Creates a new builder.
pub fn builder() -> ProverConfigBuilder {
ProverConfigBuilder::default()
}
}
/// Builder for [`ProverConfig`].
#[derive(Debug, Default)]
pub struct ProverConfigBuilder {}
impl ProverConfigBuilder {
/// Builds the configuration.
pub fn build(self) -> Result<ProverConfig, ProverConfigError> {
Ok(ProverConfig {})
}
}
/// Error for [`ProverConfig`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ProverConfigError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {}

View File

@@ -1,111 +0,0 @@
//! TLS client configuration.
use serde::{Deserialize, Serialize};
use crate::{
connection::ServerName,
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
};
/// TLS client configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsClientConfig {
server_name: ServerName,
/// Root certificates.
root_store: RootCertStore,
/// Certificate chain and a matching private key for client
/// authentication.
client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
}
impl TlsClientConfig {
/// Creates a new builder.
pub fn builder() -> TlsConfigBuilder {
TlsConfigBuilder::default()
}
/// Returns the server name.
pub fn server_name(&self) -> &ServerName {
&self.server_name
}
/// Returns the root certificates.
pub fn root_store(&self) -> &RootCertStore {
&self.root_store
}
/// Returns a certificate chain and a matching private key for client
/// authentication.
pub fn client_auth(&self) -> Option<&(Vec<CertificateDer>, PrivateKeyDer)> {
self.client_auth.as_ref()
}
}
/// Builder for [`TlsClientConfig`].
#[derive(Debug, Default)]
pub struct TlsConfigBuilder {
server_name: Option<ServerName>,
root_store: Option<RootCertStore>,
client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
}
impl TlsConfigBuilder {
/// Sets the server name.
pub fn server_name(mut self, server_name: ServerName) -> Self {
self.server_name = Some(server_name);
self
}
/// Sets the root certificates to use for verifying the server's
/// certificate.
pub fn root_store(mut self, store: RootCertStore) -> Self {
self.root_store = Some(store);
self
}
/// Sets a DER-encoded certificate chain and a matching private key for
/// client authentication.
///
/// Often the chain will consist of a single end-entity certificate.
///
/// # Arguments
///
/// * `cert_key` - A tuple containing the certificate chain and the private
/// key.
///
/// - Each certificate in the chain must be in the X.509 format.
/// - The key must be in the ASN.1 format (either PKCS#8 or PKCS#1).
pub fn client_auth(mut self, cert_key: (Vec<CertificateDer>, PrivateKeyDer)) -> Self {
self.client_auth = Some(cert_key);
self
}
/// Builds the TLS configuration.
pub fn build(self) -> Result<TlsClientConfig, TlsConfigError> {
let server_name = self.server_name.ok_or(ErrorRepr::MissingField {
field: "server_name",
})?;
let root_store = self.root_store.ok_or(ErrorRepr::MissingField {
field: "root_store",
})?;
Ok(TlsClientConfig {
server_name,
root_store,
client_auth: self.client_auth,
})
}
}
/// TLS configuration error.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct TlsConfigError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
#[error("tls config error")]
enum ErrorRepr {
#[error("missing required field: {field}")]
MissingField { field: &'static str },
}

View File

@@ -1,94 +0,0 @@
//! TLS commitment configuration.
pub mod mpc;
use serde::{Deserialize, Serialize};
/// TLS commitment configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsCommitConfig {
protocol: TlsCommitProtocolConfig,
}
impl TlsCommitConfig {
/// Creates a new builder.
pub fn builder() -> TlsCommitConfigBuilder {
TlsCommitConfigBuilder::default()
}
/// Returns the protocol configuration.
pub fn protocol(&self) -> &TlsCommitProtocolConfig {
&self.protocol
}
/// Returns a TLS commitment request.
pub fn to_request(&self) -> TlsCommitRequest {
TlsCommitRequest {
config: self.protocol.clone(),
}
}
}
/// Builder for [`TlsCommitConfig`].
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct TlsCommitConfigBuilder {
protocol: Option<TlsCommitProtocolConfig>,
}
impl TlsCommitConfigBuilder {
/// Sets the protocol configuration.
pub fn protocol<C>(mut self, protocol: C) -> Self
where
C: Into<TlsCommitProtocolConfig>,
{
self.protocol = Some(protocol.into());
self
}
/// Builds the configuration.
pub fn build(self) -> Result<TlsCommitConfig, TlsCommitConfigError> {
let protocol = self
.protocol
.ok_or(ErrorRepr::MissingField { name: "protocol" })?;
Ok(TlsCommitConfig { protocol })
}
}
/// TLS commitment protocol configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TlsCommitProtocolConfig {
/// MPC-TLS configuration.
Mpc(mpc::MpcTlsConfig),
}
impl From<mpc::MpcTlsConfig> for TlsCommitProtocolConfig {
fn from(config: mpc::MpcTlsConfig) -> Self {
Self::Mpc(config)
}
}
/// TLS commitment request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsCommitRequest {
config: TlsCommitProtocolConfig,
}
impl TlsCommitRequest {
/// Returns the protocol configuration.
pub fn protocol(&self) -> &TlsCommitProtocolConfig {
&self.config
}
}
/// Error for [`TlsCommitConfig`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct TlsCommitConfigError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("missing field: {name}")]
MissingField { name: &'static str },
}

View File

@@ -1,241 +0,0 @@
//! MPC-TLS commitment protocol configuration.
use serde::{Deserialize, Serialize};
// Default is 32 bytes to decrypt the TLS protocol messages.
const DEFAULT_MAX_RECV_ONLINE: usize = 32;
/// MPC-TLS commitment protocol configuration.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(try_from = "unchecked::MpcTlsConfigUnchecked")]
pub struct MpcTlsConfig {
/// Maximum number of bytes that can be sent.
max_sent_data: usize,
/// Maximum number of application data records that can be sent.
max_sent_records: Option<usize>,
/// Maximum number of bytes that can be decrypted online, i.e. while the
/// MPC-TLS connection is active.
max_recv_data_online: usize,
/// Maximum number of bytes that can be received.
max_recv_data: usize,
/// Maximum number of received application data records that can be
/// decrypted online, i.e. while the MPC-TLS connection is active.
max_recv_records_online: Option<usize>,
/// Whether the `deferred decryption` feature is toggled on from the start
/// of the MPC-TLS connection.
defer_decryption_from_start: bool,
/// Network settings.
network: NetworkSetting,
}
impl MpcTlsConfig {
/// Creates a new builder.
pub fn builder() -> MpcTlsConfigBuilder {
MpcTlsConfigBuilder::default()
}
/// Returns the maximum number of bytes that can be sent.
pub fn max_sent_data(&self) -> usize {
self.max_sent_data
}
/// Returns the maximum number of application data records that can
/// be sent.
pub fn max_sent_records(&self) -> Option<usize> {
self.max_sent_records
}
/// Returns the maximum number of bytes that can be decrypted online.
pub fn max_recv_data_online(&self) -> usize {
self.max_recv_data_online
}
/// Returns the maximum number of bytes that can be received.
pub fn max_recv_data(&self) -> usize {
self.max_recv_data
}
/// Returns the maximum number of received application data records that
/// can be decrypted online.
pub fn max_recv_records_online(&self) -> Option<usize> {
self.max_recv_records_online
}
/// Returns whether the `deferred decryption` feature is toggled on from the
/// start of the MPC-TLS connection.
pub fn defer_decryption_from_start(&self) -> bool {
self.defer_decryption_from_start
}
/// Returns the network settings.
pub fn network(&self) -> NetworkSetting {
self.network
}
}
fn validate(config: MpcTlsConfig) -> Result<MpcTlsConfig, MpcTlsConfigError> {
if config.max_recv_data_online > config.max_recv_data {
return Err(ErrorRepr::InvalidValue {
name: "max_recv_data_online",
reason: format!(
"must be <= max_recv_data ({} > {})",
config.max_recv_data_online, config.max_recv_data
),
}
.into());
}
Ok(config)
}
/// Builder for [`MpcTlsConfig`].
#[derive(Debug, Default)]
pub struct MpcTlsConfigBuilder {
max_sent_data: Option<usize>,
max_sent_records: Option<usize>,
max_recv_data_online: Option<usize>,
max_recv_data: Option<usize>,
max_recv_records_online: Option<usize>,
defer_decryption_from_start: Option<bool>,
network: Option<NetworkSetting>,
}
impl MpcTlsConfigBuilder {
/// Sets the maximum number of bytes that can be sent.
pub fn max_sent_data(mut self, max_sent_data: usize) -> Self {
self.max_sent_data = Some(max_sent_data);
self
}
/// Sets the maximum number of application data records that can be sent.
pub fn max_sent_records(mut self, max_sent_records: usize) -> Self {
self.max_sent_records = Some(max_sent_records);
self
}
/// Sets the maximum number of bytes that can be decrypted online.
pub fn max_recv_data_online(mut self, max_recv_data_online: usize) -> Self {
self.max_recv_data_online = Some(max_recv_data_online);
self
}
/// Sets the maximum number of bytes that can be received.
pub fn max_recv_data(mut self, max_recv_data: usize) -> Self {
self.max_recv_data = Some(max_recv_data);
self
}
/// Sets the maximum number of received application data records that can
/// be decrypted online.
pub fn max_recv_records_online(mut self, max_recv_records_online: usize) -> Self {
self.max_recv_records_online = Some(max_recv_records_online);
self
}
/// Sets whether the `deferred decryption` feature is toggled on from the
/// start of the MPC-TLS connection.
pub fn defer_decryption_from_start(mut self, defer_decryption_from_start: bool) -> Self {
self.defer_decryption_from_start = Some(defer_decryption_from_start);
self
}
/// Sets the network settings.
pub fn network(mut self, network: NetworkSetting) -> Self {
self.network = Some(network);
self
}
/// Builds the configuration.
pub fn build(self) -> Result<MpcTlsConfig, MpcTlsConfigError> {
let Self {
max_sent_data,
max_sent_records,
max_recv_data_online,
max_recv_data,
max_recv_records_online,
defer_decryption_from_start,
network,
} = self;
let max_sent_data = max_sent_data.ok_or(ErrorRepr::MissingField {
name: "max_sent_data",
})?;
let max_recv_data_online = max_recv_data_online.unwrap_or(DEFAULT_MAX_RECV_ONLINE);
let max_recv_data = max_recv_data.ok_or(ErrorRepr::MissingField {
name: "max_recv_data",
})?;
let defer_decryption_from_start = defer_decryption_from_start.unwrap_or(true);
let network = network.unwrap_or_default();
validate(MpcTlsConfig {
max_sent_data,
max_sent_records,
max_recv_data_online,
max_recv_data,
max_recv_records_online,
defer_decryption_from_start,
network,
})
}
}
/// Settings for the network environment.
///
/// Provides optimization options to adapt the protocol to different network
/// situations.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
pub enum NetworkSetting {
/// Reduces network round-trips at the expense of consuming more network
/// bandwidth.
Bandwidth,
/// Reduces network bandwidth utilization at the expense of more network
/// round-trips.
#[default]
Latency,
}
/// Error for [`MpcTlsConfig`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct MpcTlsConfigError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("missing field: {name}")]
MissingField { name: &'static str },
#[error("invalid value for field({name}): {reason}")]
InvalidValue { name: &'static str, reason: String },
}
mod unchecked {
use super::*;
#[derive(Deserialize)]
pub(super) struct MpcTlsConfigUnchecked {
max_sent_data: usize,
max_sent_records: Option<usize>,
max_recv_data_online: usize,
max_recv_data: usize,
max_recv_records_online: Option<usize>,
defer_decryption_from_start: bool,
network: NetworkSetting,
}
impl TryFrom<MpcTlsConfigUnchecked> for MpcTlsConfig {
type Error = MpcTlsConfigError;
fn try_from(value: MpcTlsConfigUnchecked) -> Result<Self, Self::Error> {
validate(MpcTlsConfig {
max_sent_data: value.max_sent_data,
max_sent_records: value.max_sent_records,
max_recv_data_online: value.max_recv_data_online,
max_recv_data: value.max_recv_data,
max_recv_records_online: value.max_recv_records_online,
defer_decryption_from_start: value.defer_decryption_from_start,
network: value.network,
})
}
}
}

View File

@@ -1,56 +0,0 @@
//! Verifier configuration.
use serde::{Deserialize, Serialize};
use crate::webpki::RootCertStore;
/// Verifier configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VerifierConfig {
root_store: RootCertStore,
}
impl VerifierConfig {
/// Creates a new builder.
pub fn builder() -> VerifierConfigBuilder {
VerifierConfigBuilder::default()
}
/// Returns the root certificate store.
pub fn root_store(&self) -> &RootCertStore {
&self.root_store
}
}
/// Builder for [`VerifierConfig`].
#[derive(Debug, Default)]
pub struct VerifierConfigBuilder {
root_store: Option<RootCertStore>,
}
impl VerifierConfigBuilder {
/// Sets the root certificate store.
pub fn root_store(mut self, root_store: RootCertStore) -> Self {
self.root_store = Some(root_store);
self
}
/// Builds the configuration.
pub fn build(self) -> Result<VerifierConfig, VerifierConfigError> {
let root_store = self
.root_store
.ok_or(ErrorRepr::MissingField { name: "root_store" })?;
Ok(VerifierConfig { root_store })
}
}
/// Error for [`VerifierConfig`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct VerifierConfigError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("missing field: {name}")]
MissingField { name: &'static str },
}

View File

@@ -116,75 +116,84 @@ pub enum KeyType {
SECP256R1 = 0x0017,
}
/// Signature algorithm used on the key exchange parameters.
/// Signature scheme on the key exchange parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[allow(non_camel_case_types, missing_docs)]
pub enum SignatureAlgorithm {
ECDSA_NISTP256_SHA256,
ECDSA_NISTP256_SHA384,
ECDSA_NISTP384_SHA256,
ECDSA_NISTP384_SHA384,
ED25519,
RSA_PKCS1_2048_8192_SHA256,
RSA_PKCS1_2048_8192_SHA384,
RSA_PKCS1_2048_8192_SHA512,
RSA_PSS_2048_8192_SHA256_LEGACY_KEY,
RSA_PSS_2048_8192_SHA384_LEGACY_KEY,
RSA_PSS_2048_8192_SHA512_LEGACY_KEY,
pub enum SignatureScheme {
RSA_PKCS1_SHA1 = 0x0201,
ECDSA_SHA1_Legacy = 0x0203,
RSA_PKCS1_SHA256 = 0x0401,
ECDSA_NISTP256_SHA256 = 0x0403,
RSA_PKCS1_SHA384 = 0x0501,
ECDSA_NISTP384_SHA384 = 0x0503,
RSA_PKCS1_SHA512 = 0x0601,
ECDSA_NISTP521_SHA512 = 0x0603,
RSA_PSS_SHA256 = 0x0804,
RSA_PSS_SHA384 = 0x0805,
RSA_PSS_SHA512 = 0x0806,
ED25519 = 0x0807,
}
impl fmt::Display for SignatureAlgorithm {
impl fmt::Display for SignatureScheme {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SignatureAlgorithm::ECDSA_NISTP256_SHA256 => write!(f, "ECDSA_NISTP256_SHA256"),
SignatureAlgorithm::ECDSA_NISTP256_SHA384 => write!(f, "ECDSA_NISTP256_SHA384"),
SignatureAlgorithm::ECDSA_NISTP384_SHA256 => write!(f, "ECDSA_NISTP384_SHA256"),
SignatureAlgorithm::ECDSA_NISTP384_SHA384 => write!(f, "ECDSA_NISTP384_SHA384"),
SignatureAlgorithm::ED25519 => write!(f, "ED25519"),
SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA256 => {
write!(f, "RSA_PKCS1_2048_8192_SHA256")
}
SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA384 => {
write!(f, "RSA_PKCS1_2048_8192_SHA384")
}
SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA512 => {
write!(f, "RSA_PKCS1_2048_8192_SHA512")
}
SignatureAlgorithm::RSA_PSS_2048_8192_SHA256_LEGACY_KEY => {
write!(f, "RSA_PSS_2048_8192_SHA256_LEGACY_KEY")
}
SignatureAlgorithm::RSA_PSS_2048_8192_SHA384_LEGACY_KEY => {
write!(f, "RSA_PSS_2048_8192_SHA384_LEGACY_KEY")
}
SignatureAlgorithm::RSA_PSS_2048_8192_SHA512_LEGACY_KEY => {
write!(f, "RSA_PSS_2048_8192_SHA512_LEGACY_KEY")
}
SignatureScheme::RSA_PKCS1_SHA1 => write!(f, "RSA_PKCS1_SHA1"),
SignatureScheme::ECDSA_SHA1_Legacy => write!(f, "ECDSA_SHA1_Legacy"),
SignatureScheme::RSA_PKCS1_SHA256 => write!(f, "RSA_PKCS1_SHA256"),
SignatureScheme::ECDSA_NISTP256_SHA256 => write!(f, "ECDSA_NISTP256_SHA256"),
SignatureScheme::RSA_PKCS1_SHA384 => write!(f, "RSA_PKCS1_SHA384"),
SignatureScheme::ECDSA_NISTP384_SHA384 => write!(f, "ECDSA_NISTP384_SHA384"),
SignatureScheme::RSA_PKCS1_SHA512 => write!(f, "RSA_PKCS1_SHA512"),
SignatureScheme::ECDSA_NISTP521_SHA512 => write!(f, "ECDSA_NISTP521_SHA512"),
SignatureScheme::RSA_PSS_SHA256 => write!(f, "RSA_PSS_SHA256"),
SignatureScheme::RSA_PSS_SHA384 => write!(f, "RSA_PSS_SHA384"),
SignatureScheme::RSA_PSS_SHA512 => write!(f, "RSA_PSS_SHA512"),
SignatureScheme::ED25519 => write!(f, "ED25519"),
}
}
}
impl From<tls_core::verify::SignatureAlgorithm> for SignatureAlgorithm {
fn from(value: tls_core::verify::SignatureAlgorithm) -> Self {
use tls_core::verify::SignatureAlgorithm as Core;
impl TryFrom<tls_core::msgs::enums::SignatureScheme> for SignatureScheme {
type Error = &'static str;
fn try_from(value: tls_core::msgs::enums::SignatureScheme) -> Result<Self, Self::Error> {
use tls_core::msgs::enums::SignatureScheme as Core;
use SignatureScheme::*;
Ok(match value {
Core::RSA_PKCS1_SHA1 => RSA_PKCS1_SHA1,
Core::ECDSA_SHA1_Legacy => ECDSA_SHA1_Legacy,
Core::RSA_PKCS1_SHA256 => RSA_PKCS1_SHA256,
Core::ECDSA_NISTP256_SHA256 => ECDSA_NISTP256_SHA256,
Core::RSA_PKCS1_SHA384 => RSA_PKCS1_SHA384,
Core::ECDSA_NISTP384_SHA384 => ECDSA_NISTP384_SHA384,
Core::RSA_PKCS1_SHA512 => RSA_PKCS1_SHA512,
Core::ECDSA_NISTP521_SHA512 => ECDSA_NISTP521_SHA512,
Core::RSA_PSS_SHA256 => RSA_PSS_SHA256,
Core::RSA_PSS_SHA384 => RSA_PSS_SHA384,
Core::RSA_PSS_SHA512 => RSA_PSS_SHA512,
Core::ED25519 => ED25519,
_ => return Err("unsupported signature scheme"),
})
}
}
impl From<SignatureScheme> for tls_core::msgs::enums::SignatureScheme {
fn from(value: SignatureScheme) -> Self {
use tls_core::msgs::enums::SignatureScheme::*;
match value {
Core::ECDSA_NISTP256_SHA256 => SignatureAlgorithm::ECDSA_NISTP256_SHA256,
Core::ECDSA_NISTP256_SHA384 => SignatureAlgorithm::ECDSA_NISTP256_SHA384,
Core::ECDSA_NISTP384_SHA256 => SignatureAlgorithm::ECDSA_NISTP384_SHA256,
Core::ECDSA_NISTP384_SHA384 => SignatureAlgorithm::ECDSA_NISTP384_SHA384,
Core::ED25519 => SignatureAlgorithm::ED25519,
Core::RSA_PKCS1_2048_8192_SHA256 => SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA256,
Core::RSA_PKCS1_2048_8192_SHA384 => SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA384,
Core::RSA_PKCS1_2048_8192_SHA512 => SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA512,
Core::RSA_PSS_2048_8192_SHA256_LEGACY_KEY => {
SignatureAlgorithm::RSA_PSS_2048_8192_SHA256_LEGACY_KEY
}
Core::RSA_PSS_2048_8192_SHA384_LEGACY_KEY => {
SignatureAlgorithm::RSA_PSS_2048_8192_SHA384_LEGACY_KEY
}
Core::RSA_PSS_2048_8192_SHA512_LEGACY_KEY => {
SignatureAlgorithm::RSA_PSS_2048_8192_SHA512_LEGACY_KEY
}
SignatureScheme::RSA_PKCS1_SHA1 => RSA_PKCS1_SHA1,
SignatureScheme::ECDSA_SHA1_Legacy => ECDSA_SHA1_Legacy,
SignatureScheme::RSA_PKCS1_SHA256 => RSA_PKCS1_SHA256,
SignatureScheme::ECDSA_NISTP256_SHA256 => ECDSA_NISTP256_SHA256,
SignatureScheme::RSA_PKCS1_SHA384 => RSA_PKCS1_SHA384,
SignatureScheme::ECDSA_NISTP384_SHA384 => ECDSA_NISTP384_SHA384,
SignatureScheme::RSA_PKCS1_SHA512 => RSA_PKCS1_SHA512,
SignatureScheme::ECDSA_NISTP521_SHA512 => ECDSA_NISTP521_SHA512,
SignatureScheme::RSA_PSS_SHA256 => RSA_PSS_SHA256,
SignatureScheme::RSA_PSS_SHA384 => RSA_PSS_SHA384,
SignatureScheme::RSA_PSS_SHA512 => RSA_PSS_SHA512,
SignatureScheme::ED25519 => ED25519,
}
}
}
@@ -192,8 +201,8 @@ impl From<tls_core::verify::SignatureAlgorithm> for SignatureAlgorithm {
/// Server's signature of the key exchange parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerSignature {
/// Signature algorithm.
pub alg: SignatureAlgorithm,
/// Signature scheme.
pub scheme: SignatureScheme,
/// Signature data.
pub sig: Vec<u8>,
}
@@ -350,23 +359,20 @@ impl HandshakeData {
message.extend_from_slice(&server_ephemeral_key.kx_params());
use webpki::ring as alg;
let sig_alg = match self.sig.alg {
SignatureAlgorithm::ECDSA_NISTP256_SHA256 => alg::ECDSA_P256_SHA256,
SignatureAlgorithm::ECDSA_NISTP256_SHA384 => alg::ECDSA_P256_SHA384,
SignatureAlgorithm::ECDSA_NISTP384_SHA256 => alg::ECDSA_P384_SHA256,
SignatureAlgorithm::ECDSA_NISTP384_SHA384 => alg::ECDSA_P384_SHA384,
SignatureAlgorithm::ED25519 => alg::ED25519,
SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA256 => alg::RSA_PKCS1_2048_8192_SHA256,
SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA384 => alg::RSA_PKCS1_2048_8192_SHA384,
SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA512 => alg::RSA_PKCS1_2048_8192_SHA512,
SignatureAlgorithm::RSA_PSS_2048_8192_SHA256_LEGACY_KEY => {
alg::RSA_PSS_2048_8192_SHA256_LEGACY_KEY
}
SignatureAlgorithm::RSA_PSS_2048_8192_SHA384_LEGACY_KEY => {
alg::RSA_PSS_2048_8192_SHA384_LEGACY_KEY
}
SignatureAlgorithm::RSA_PSS_2048_8192_SHA512_LEGACY_KEY => {
alg::RSA_PSS_2048_8192_SHA512_LEGACY_KEY
let sig_alg = match self.sig.scheme {
SignatureScheme::RSA_PKCS1_SHA256 => alg::RSA_PKCS1_2048_8192_SHA256,
SignatureScheme::RSA_PKCS1_SHA384 => alg::RSA_PKCS1_2048_8192_SHA384,
SignatureScheme::RSA_PKCS1_SHA512 => alg::RSA_PKCS1_2048_8192_SHA512,
SignatureScheme::RSA_PSS_SHA256 => alg::RSA_PSS_2048_8192_SHA256_LEGACY_KEY,
SignatureScheme::RSA_PSS_SHA384 => alg::RSA_PSS_2048_8192_SHA384_LEGACY_KEY,
SignatureScheme::RSA_PSS_SHA512 => alg::RSA_PSS_2048_8192_SHA512_LEGACY_KEY,
SignatureScheme::ECDSA_NISTP256_SHA256 => alg::ECDSA_P256_SHA256,
SignatureScheme::ECDSA_NISTP384_SHA384 => alg::ECDSA_P384_SHA384,
SignatureScheme::ED25519 => alg::ED25519,
scheme => {
return Err(HandshakeVerificationError::UnsupportedSignatureScheme(
scheme,
))
}
};
@@ -396,6 +402,8 @@ pub enum HandshakeVerificationError {
InvalidServerEphemeralKey,
#[error("server certificate verification failed: {0}")]
ServerCert(ServerCertVerifierError),
#[error("unsupported signature scheme: {0}")]
UnsupportedSignatureScheme(SignatureScheme),
}
#[cfg(test)]

View File

@@ -1,11 +1,11 @@
use rangeset::set::RangeSet;
use rangeset::RangeSet;
pub(crate) struct FmtRangeSet<'a>(pub &'a RangeSet<usize>);
impl<'a> std::fmt::Display for FmtRangeSet<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("{")?;
for range in self.0.iter() {
for range in self.0.iter_ranges() {
write!(f, "{}..{}", range.start, range.end)?;
if range.end < self.0.end().unwrap_or(0) {
f.write_str(", ")?;

View File

@@ -1,14 +1,20 @@
//! Fixtures for testing
mod provider;
pub mod transcript;
pub use provider::FixtureEncodingProvider;
use hex::FromHex;
use crate::{
connection::{
CertBinding, CertBindingV1_2, ConnectionInfo, DnsName, HandshakeData, KeyType,
ServerEphemKey, ServerName, ServerSignature, SignatureAlgorithm, TlsVersion,
TranscriptLength,
ServerEphemKey, ServerName, ServerSignature, SignatureScheme, TlsVersion, TranscriptLength,
},
transcript::{
encoding::{EncoderSecret, EncodingProvider},
Transcript,
},
webpki::CertificateDer,
};
@@ -41,7 +47,7 @@ impl ConnectionFixture {
CertificateDer(include_bytes!("fixtures/data/tlsnotary.org/ca.der").to_vec()),
],
sig: ServerSignature {
alg: SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA256,
scheme: SignatureScheme::RSA_PKCS1_SHA256,
sig: Vec::<u8>::from_hex(include_bytes!(
"fixtures/data/tlsnotary.org/signature"
))
@@ -86,7 +92,7 @@ impl ConnectionFixture {
CertificateDer(include_bytes!("fixtures/data/appliedzkp.org/ca.der").to_vec()),
],
sig: ServerSignature {
alg: SignatureAlgorithm::ECDSA_NISTP256_SHA256,
scheme: SignatureScheme::ECDSA_NISTP256_SHA256,
sig: Vec::<u8>::from_hex(include_bytes!(
"fixtures/data/appliedzkp.org/signature"
))
@@ -122,3 +128,27 @@ impl ConnectionFixture {
server_ephemeral_key
}
}
/// Returns an encoding provider fixture.
pub fn encoding_provider(tx: &[u8], rx: &[u8]) -> impl EncodingProvider {
let secret = encoder_secret();
FixtureEncodingProvider::new(&secret, Transcript::new(tx, rx))
}
/// Seed fixture.
const SEED: [u8; 32] = [0; 32];
/// Delta fixture.
const DELTA: [u8; 16] = [1; 16];
/// Returns an encoder secret fixture.
pub fn encoder_secret() -> EncoderSecret {
EncoderSecret::new(SEED, DELTA)
}
/// Returns a tampered encoder secret fixture.
pub fn encoder_secret_tampered_seed() -> EncoderSecret {
let mut seed = SEED;
seed[0] += 1;
EncoderSecret::new(seed, DELTA)
}

View File

@@ -0,0 +1,41 @@
use std::ops::Range;
use crate::transcript::{
encoding::{new_encoder, Encoder, EncoderSecret, EncodingProvider, EncodingProviderError},
Direction, Transcript,
};
/// An encoding provider fixture.
pub struct FixtureEncodingProvider {
    /// Encoder derived from the fixture secret; produces the encodings.
    encoder: Box<dyn Encoder>,
    /// Transcript whose plaintext is encoded on demand.
    transcript: Transcript,
}
impl FixtureEncodingProvider {
/// Creates a new encoding provider fixture.
pub(crate) fn new(secret: &EncoderSecret, transcript: Transcript) -> Self {
Self {
encoder: Box::new(new_encoder(secret)),
transcript,
}
}
}
impl EncodingProvider for FixtureEncodingProvider {
    fn provide_encoding(
        &self,
        direction: Direction,
        range: Range<usize>,
        dest: &mut Vec<u8>,
    ) -> Result<(), EncodingProviderError> {
        // Pick the side of the transcript indicated by `direction`.
        let data = match direction {
            Direction::Sent => self.transcript.sent(),
            Direction::Received => self.transcript.received(),
        };

        // An out-of-bounds range is surfaced as a provider error.
        let plaintext = data.get(range.clone()).ok_or(EncodingProviderError)?;

        self.encoder.encode_data(direction, range, plaintext, dest);

        Ok(())
    }
}

View File

@@ -2,13 +2,12 @@
use aead::Payload as AeadPayload;
use aes_gcm::{aead::Aead, Aes128Gcm, NewAead};
#[allow(deprecated)]
use generic_array::GenericArray;
use rand::{rngs::StdRng, Rng, SeedableRng};
use tls_core::msgs::{
base::Payload,
codec::Codec,
enums::{HandshakeType, ProtocolVersion},
enums::{ContentType, HandshakeType, ProtocolVersion},
handshake::{HandshakeMessagePayload, HandshakePayload},
message::{OpaqueMessage, PlainMessage},
};
@@ -16,7 +15,7 @@ use tls_core::msgs::{
use crate::{
connection::{TranscriptLength, VerifyData},
fixtures::ConnectionFixture,
transcript::{ContentType, Record, TlsTranscript},
transcript::{Record, TlsTranscript},
};
/// The key used for encryption of the sent and received transcript.
@@ -104,7 +103,7 @@ impl TranscriptGenerator {
let explicit_nonce: [u8; 8] = seq.to_be_bytes();
let msg = PlainMessage {
typ: ContentType::ApplicationData.into(),
typ: ContentType::ApplicationData,
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(plaintext),
};
@@ -139,7 +138,7 @@ impl TranscriptGenerator {
handshake_message.encode(&mut plaintext);
let msg = PlainMessage {
typ: ContentType::Handshake.into(),
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(plaintext.clone()),
};
@@ -181,7 +180,6 @@ fn aes_gcm_encrypt(
let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&iv);
nonce[4..].copy_from_slice(&explicit_nonce);
#[allow(deprecated)]
let nonce = GenericArray::from_slice(&nonce);
let cipher = Aes128Gcm::new_from_slice(&key).unwrap();

View File

@@ -296,14 +296,14 @@ mod sha2 {
fn hash(&self, data: &[u8]) -> super::Hash {
let mut hasher = ::sha2::Sha256::default();
hasher.update(data);
super::Hash::new(hasher.finalize().as_ref())
super::Hash::new(hasher.finalize().as_slice())
}
fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
let mut hasher = ::sha2::Sha256::default();
hasher.update(prefix);
hasher.update(data);
super::Hash::new(hasher.finalize().as_ref())
super::Hash::new(hasher.finalize().as_slice())
}
}
}

View File

@@ -12,16 +12,196 @@ pub mod merkle;
pub mod transcript;
pub mod webpki;
pub use rangeset;
pub mod config;
pub(crate) mod display;
use rangeset::{RangeSet, ToRangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use crate::{
connection::ServerName,
transcript::{PartialTranscript, TranscriptCommitment, TranscriptSecret},
connection::{HandshakeData, ServerName},
transcript::{
Direction, PartialTranscript, Transcript, TranscriptCommitConfig, TranscriptCommitRequest,
TranscriptCommitment, TranscriptSecret,
},
};
/// Configuration to prove information to the verifier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProveConfig {
    // Whether the server identity is proven to the verifier.
    server_identity: bool,
    // (sent, received) transcript ranges to reveal, if any.
    reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
    // Optional transcript commitment configuration.
    transcript_commit: Option<TranscriptCommitConfig>,
}
impl ProveConfig {
    /// Creates a new builder for the given transcript.
    pub fn builder(transcript: &Transcript) -> ProveConfigBuilder<'_> {
        ProveConfigBuilder::new(transcript)
    }

    /// Returns `true` if the server identity is to be proven.
    pub fn server_identity(&self) -> bool {
        self.server_identity
    }

    /// Returns the ranges of the transcript to be revealed, as a
    /// `(sent, received)` pair. `None` means nothing is revealed.
    pub fn reveal(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
        self.reveal.as_ref()
    }

    /// Returns the transcript commitment configuration, if configured.
    pub fn transcript_commit(&self) -> Option<&TranscriptCommitConfig> {
        self.transcript_commit.as_ref()
    }
}
/// Builder for [`ProveConfig`].
#[derive(Debug)]
pub struct ProveConfigBuilder<'a> {
    // Transcript used to bounds-check the reveal ranges.
    transcript: &'a Transcript,
    // Whether the server identity is to be proven.
    server_identity: bool,
    // Accumulated (sent, received) reveal ranges.
    reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
    // Optional transcript commitment configuration.
    transcript_commit: Option<TranscriptCommitConfig>,
}
impl<'a> ProveConfigBuilder<'a> {
    /// Creates a new builder over the given transcript.
    pub fn new(transcript: &'a Transcript) -> Self {
        Self {
            transcript,
            server_identity: false,
            reveal: None,
            transcript_commit: None,
        }
    }

    /// Proves the server identity.
    pub fn server_identity(&mut self) -> &mut Self {
        self.server_identity = true;
        self
    }

    /// Configures transcript commitments.
    pub fn transcript_commit(&mut self, transcript_commit: TranscriptCommitConfig) -> &mut Self {
        self.transcript_commit = Some(transcript_commit);
        self
    }

    /// Reveals the given ranges of the transcript.
    ///
    /// Returns an error if any range extends past the end of the transcript
    /// in the given direction.
    pub fn reveal(
        &mut self,
        direction: Direction,
        ranges: &dyn ToRangeSet<usize>,
    ) -> Result<&mut Self, ProveConfigBuilderError> {
        let idx = ranges.to_range_set();

        // Bounds-check against the transcript before recording anything.
        let actual = idx.end().unwrap_or(0);
        let len = self.transcript.len_of_direction(direction);
        if actual > len {
            return Err(ProveConfigBuilderError(
                ProveConfigBuilderErrorRepr::IndexOutOfBounds {
                    direction,
                    actual,
                    len,
                },
            ));
        }

        let (sent, recv) = self.reveal.get_or_insert_default();
        let target = match direction {
            Direction::Sent => sent,
            Direction::Received => recv,
        };
        target.union_mut(&idx);

        Ok(self)
    }

    /// Reveals the given ranges of the sent data transcript.
    pub fn reveal_sent(
        &mut self,
        ranges: &dyn ToRangeSet<usize>,
    ) -> Result<&mut Self, ProveConfigBuilderError> {
        self.reveal(Direction::Sent, ranges)
    }

    /// Reveals the given ranges of the received data transcript.
    pub fn reveal_recv(
        &mut self,
        ranges: &dyn ToRangeSet<usize>,
    ) -> Result<&mut Self, ProveConfigBuilderError> {
        self.reveal(Direction::Received, ranges)
    }

    /// Builds the configuration.
    pub fn build(self) -> Result<ProveConfig, ProveConfigBuilderError> {
        let Self {
            server_identity,
            reveal,
            transcript_commit,
            ..
        } = self;

        Ok(ProveConfig {
            server_identity,
            reveal,
            transcript_commit,
        })
    }
}
/// Error for [`ProveConfigBuilder`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ProveConfigBuilderError(#[from] ProveConfigBuilderErrorRepr);

// Private representation so variants can evolve without breaking the
// public API.
#[derive(Debug, thiserror::Error)]
enum ProveConfigBuilderErrorRepr {
    /// A reveal range extends past the end of the transcript.
    #[error("range is out of bounds of the transcript ({direction}): {actual} > {len}")]
    IndexOutOfBounds {
        direction: Direction,
        actual: usize,
        len: usize,
    },
}
/// Configuration to verify information from the prover.
///
/// Currently carries no options.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct VerifyConfig {}

impl VerifyConfig {
    /// Creates a new builder.
    pub fn builder() -> VerifyConfigBuilder {
        VerifyConfigBuilder::new()
    }
}
/// Builder for [`VerifyConfig`].
#[derive(Debug, Default)]
pub struct VerifyConfigBuilder {}

impl VerifyConfigBuilder {
    /// Creates a new builder.
    pub fn new() -> Self {
        Self {}
    }

    /// Builds the configuration.
    ///
    /// Cannot fail today: the error's internal enum has no variants.
    pub fn build(self) -> Result<VerifyConfig, VerifyConfigBuilderError> {
        Ok(VerifyConfig {})
    }
}
/// Error for [`VerifyConfigBuilder`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct VerifyConfigBuilderError(#[from] VerifyConfigBuilderErrorRepr);

// Uninhabited: building the config cannot currently fail.
#[derive(Debug, thiserror::Error)]
enum VerifyConfigBuilderErrorRepr {}
/// Payload sent to the verifier.
#[doc(hidden)]
#[derive(Debug, Serialize, Deserialize)]
pub struct ProvePayload {
    /// Handshake data.
    ///
    /// NOTE(review): presumably populated when `ProveConfig::server_identity`
    /// is set — confirm at the call site.
    pub handshake: Option<(ServerName, HandshakeData)>,
    /// Transcript data.
    pub transcript: Option<PartialTranscript>,
    /// Transcript commitment configuration.
    pub transcript_commit: Option<TranscriptCommitRequest>,
}
/// Prover output.
#[derive(Serialize, Deserialize)]
pub struct ProverOutput {

View File

@@ -63,6 +63,11 @@ impl MerkleProof {
Ok(())
}
/// Returns the leaf count of the Merkle tree associated with the proof.
pub(crate) fn leaf_count(&self) -> usize {
self.leaf_count
}
}
#[derive(Clone)]

View File

@@ -19,17 +19,14 @@
//! withheld.
mod commit;
pub mod encoding;
pub mod hash;
mod proof;
mod tls;
use std::{fmt, ops::Range};
use rangeset::{
iter::RangeIterator,
ops::{Index, Set},
set::RangeSet,
};
use rangeset::{Difference, IndexRanges, RangeSet, Union};
use serde::{Deserialize, Serialize};
use crate::connection::TranscriptLength;
@@ -41,7 +38,8 @@ pub use commit::{
pub use proof::{
TranscriptProof, TranscriptProofBuilder, TranscriptProofBuilderError, TranscriptProofError,
};
pub use tls::{ContentType, Record, TlsTranscript};
pub use tls::{Record, TlsTranscript};
pub use tls_core::msgs::enums::ContentType;
/// A transcript contains the plaintext of all application data communicated
/// between the Prover and the Server.
@@ -109,14 +107,8 @@ impl Transcript {
}
Some(
Subsequence::new(
idx.clone(),
data.index(idx).fold(Vec::new(), |mut acc, s| {
acc.extend_from_slice(s);
acc
}),
)
.expect("data is same length as index"),
Subsequence::new(idx.clone(), data.index_ranges(idx))
.expect("data is same length as index"),
)
}
@@ -138,11 +130,11 @@ impl Transcript {
let mut sent = vec![0; self.sent.len()];
let mut received = vec![0; self.received.len()];
for range in sent_idx.iter() {
for range in sent_idx.iter_ranges() {
sent[range.clone()].copy_from_slice(&self.sent[range]);
}
for range in recv_idx.iter() {
for range in recv_idx.iter_ranges() {
received[range.clone()].copy_from_slice(&self.received[range]);
}
@@ -195,20 +187,12 @@ pub struct CompressedPartialTranscript {
impl From<PartialTranscript> for CompressedPartialTranscript {
fn from(uncompressed: PartialTranscript) -> Self {
Self {
sent_authed: uncompressed.sent.index(&uncompressed.sent_authed_idx).fold(
Vec::new(),
|mut acc, s| {
acc.extend_from_slice(s);
acc
},
),
sent_authed: uncompressed
.sent
.index_ranges(&uncompressed.sent_authed_idx),
received_authed: uncompressed
.received
.index(&uncompressed.received_authed_idx)
.fold(Vec::new(), |mut acc, s| {
acc.extend_from_slice(s);
acc
}),
.index_ranges(&uncompressed.received_authed_idx),
sent_idx: uncompressed.sent_authed_idx,
recv_idx: uncompressed.received_authed_idx,
sent_total: uncompressed.sent.len(),
@@ -224,7 +208,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
let mut offset = 0;
for range in compressed.sent_idx.iter() {
for range in compressed.sent_idx.iter_ranges() {
sent[range.clone()]
.copy_from_slice(&compressed.sent_authed[offset..offset + range.len()]);
offset += range.len();
@@ -232,7 +216,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
let mut offset = 0;
for range in compressed.recv_idx.iter() {
for range in compressed.recv_idx.iter_ranges() {
received[range.clone()]
.copy_from_slice(&compressed.received_authed[offset..offset + range.len()]);
offset += range.len();
@@ -321,16 +305,12 @@ impl PartialTranscript {
/// Returns the index of sent data which haven't been authenticated.
pub fn sent_unauthed(&self) -> RangeSet<usize> {
(0..self.sent.len())
.difference(&self.sent_authed_idx)
.into_set()
(0..self.sent.len()).difference(&self.sent_authed_idx)
}
/// Returns the index of received data which haven't been authenticated.
pub fn received_unauthed(&self) -> RangeSet<usize> {
(0..self.received.len())
.difference(&self.received_authed_idx)
.into_set()
(0..self.received.len()).difference(&self.received_authed_idx)
}
/// Returns an iterator over the authenticated data in the transcript.
@@ -340,7 +320,7 @@ impl PartialTranscript {
Direction::Received => (&self.received, &self.received_authed_idx),
};
authed.iter_values().map(move |i| data[i])
authed.iter().map(|i| data[i])
}
/// Unions the authenticated data of this transcript with another.
@@ -360,20 +340,24 @@ impl PartialTranscript {
"received data are not the same length"
);
for range in other.sent_authed_idx.difference(&self.sent_authed_idx) {
for range in other
.sent_authed_idx
.difference(&self.sent_authed_idx)
.iter_ranges()
{
self.sent[range.clone()].copy_from_slice(&other.sent[range]);
}
for range in other
.received_authed_idx
.difference(&self.received_authed_idx)
.iter_ranges()
{
self.received[range.clone()].copy_from_slice(&other.received[range]);
}
self.sent_authed_idx.union_mut(&other.sent_authed_idx);
self.received_authed_idx
.union_mut(&other.received_authed_idx);
self.sent_authed_idx = self.sent_authed_idx.union(&other.sent_authed_idx);
self.received_authed_idx = self.received_authed_idx.union(&other.received_authed_idx);
}
/// Unions an authenticated subsequence into this transcript.
@@ -385,11 +369,11 @@ impl PartialTranscript {
match direction {
Direction::Sent => {
seq.copy_to(&mut self.sent);
self.sent_authed_idx.union_mut(&seq.idx);
self.sent_authed_idx = self.sent_authed_idx.union(&seq.idx);
}
Direction::Received => {
seq.copy_to(&mut self.received);
self.received_authed_idx.union_mut(&seq.idx);
self.received_authed_idx = self.received_authed_idx.union(&seq.idx);
}
}
}
@@ -400,10 +384,10 @@ impl PartialTranscript {
///
/// * `value` - The value to set the unauthenticated bytes to
pub fn set_unauthed(&mut self, value: u8) {
for range in self.sent_unauthed().iter() {
for range in self.sent_unauthed().iter_ranges() {
self.sent[range].fill(value);
}
for range in self.received_unauthed().iter() {
for range in self.received_unauthed().iter_ranges() {
self.received[range].fill(value);
}
}
@@ -418,13 +402,13 @@ impl PartialTranscript {
pub fn set_unauthed_range(&mut self, value: u8, direction: Direction, range: Range<usize>) {
match direction {
Direction::Sent => {
for r in range.difference(&self.sent_authed_idx) {
self.sent[r].fill(value);
for range in range.difference(&self.sent_authed_idx).iter_ranges() {
self.sent[range].fill(value);
}
}
Direction::Received => {
for r in range.difference(&self.received_authed_idx) {
self.received[r].fill(value);
for range in range.difference(&self.received_authed_idx).iter_ranges() {
self.received[range].fill(value);
}
}
}
@@ -502,7 +486,7 @@ impl Subsequence {
/// Panics if the subsequence ranges are out of bounds.
pub(crate) fn copy_to(&self, dest: &mut [u8]) {
let mut offset = 0;
for range in self.idx.iter() {
for range in self.idx.iter_ranges() {
dest[range.clone()].copy_from_slice(&self.data[offset..offset + range.len()]);
offset += range.len();
}
@@ -627,7 +611,12 @@ mod validation {
mut partial_transcript: CompressedPartialTranscriptUnchecked,
) {
// Change the total to be less than the last range's end bound.
let end = partial_transcript.sent_idx.iter().next_back().unwrap().end;
let end = partial_transcript
.sent_idx
.iter_ranges()
.next_back()
.unwrap()
.end;
partial_transcript.sent_total = end - 1;

View File

@@ -2,21 +2,33 @@
use std::{collections::HashSet, fmt};
use rangeset::set::ToRangeSet;
use rangeset::ToRangeSet;
use serde::{Deserialize, Serialize};
use crate::{
hash::HashAlgId,
transcript::{
encoding::{EncodingCommitment, EncodingTree},
hash::{PlaintextHash, PlaintextHashSecret},
Direction, RangeSet, Transcript,
},
};
/// The maximum allowed total bytelength of committed data for a single
/// commitment kind. Used to prevent DoS during verification. (May cause the
/// verifier to hash up to a max of 1GB * 128 = 128GB of data for certain kinds
/// of encoding commitments.)
///
/// This value must not exceed bcs's MAX_SEQUENCE_LENGTH limit (which is (1 <<
/// 31) - 1 by default)
pub(crate) const MAX_TOTAL_COMMITTED_DATA: usize = 1_000_000_000;
/// Kind of transcript commitment.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TranscriptCommitmentKind {
/// A commitment to encodings of the transcript.
Encoding,
/// A hash commitment to plaintext in the transcript.
Hash {
/// The hash algorithm used.
@@ -27,6 +39,7 @@ pub enum TranscriptCommitmentKind {
impl fmt::Display for TranscriptCommitmentKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Encoding => f.write_str("encoding"),
Self::Hash { alg } => write!(f, "hash ({alg})"),
}
}
@@ -36,6 +49,8 @@ impl fmt::Display for TranscriptCommitmentKind {
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TranscriptCommitment {
/// Encoding commitment.
Encoding(EncodingCommitment),
/// Plaintext hash commitment.
Hash(PlaintextHash),
}
@@ -44,6 +59,8 @@ pub enum TranscriptCommitment {
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TranscriptSecret {
/// Encoding tree.
Encoding(EncodingTree),
/// Plaintext hash secret.
Hash(PlaintextHashSecret),
}
@@ -51,6 +68,9 @@ pub enum TranscriptSecret {
/// Configuration for transcript commitments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptCommitConfig {
encoding_hash_alg: HashAlgId,
has_encoding: bool,
has_hash: bool,
commits: Vec<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
}
@@ -60,23 +80,41 @@ impl TranscriptCommitConfig {
TranscriptCommitConfigBuilder::new(transcript)
}
/// Returns the hash algorithm to use for encoding commitments.
pub fn encoding_hash_alg(&self) -> &HashAlgId {
&self.encoding_hash_alg
}
/// Returns `true` if the configuration has any encoding commitments.
pub fn has_encoding(&self) -> bool {
self.has_encoding
}
/// Returns `true` if the configuration has any hash commitments.
pub fn has_hash(&self) -> bool {
self.commits
.iter()
.any(|(_, kind)| matches!(kind, TranscriptCommitmentKind::Hash { .. }))
self.has_hash
}
/// Returns an iterator over the encoding commitment indices.
pub fn iter_encoding(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
self.commits.iter().filter_map(|(idx, kind)| match kind {
TranscriptCommitmentKind::Encoding => Some(idx),
_ => None,
})
}
/// Returns an iterator over the hash commitment indices.
pub fn iter_hash(&self) -> impl Iterator<Item = (&(Direction, RangeSet<usize>), &HashAlgId)> {
self.commits.iter().map(|(idx, kind)| match kind {
TranscriptCommitmentKind::Hash { alg } => (idx, alg),
self.commits.iter().filter_map(|(idx, kind)| match kind {
TranscriptCommitmentKind::Hash { alg } => Some((idx, alg)),
_ => None,
})
}
/// Returns a request for the transcript commitments.
pub fn to_request(&self) -> TranscriptCommitRequest {
TranscriptCommitRequest {
encoding: self.has_encoding,
hash: self
.iter_hash()
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg))
@@ -86,9 +124,15 @@ impl TranscriptCommitConfig {
}
/// A builder for [`TranscriptCommitConfig`].
///
/// The default hash algorithm is [`HashAlgId::BLAKE3`] and the default kind
/// is [`TranscriptCommitmentKind::Encoding`].
#[derive(Debug)]
pub struct TranscriptCommitConfigBuilder<'a> {
transcript: &'a Transcript,
encoding_hash_alg: HashAlgId,
has_encoding: bool,
has_hash: bool,
default_kind: TranscriptCommitmentKind,
commits: HashSet<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
}
@@ -98,13 +142,20 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
pub fn new(transcript: &'a Transcript) -> Self {
Self {
transcript,
default_kind: TranscriptCommitmentKind::Hash {
alg: HashAlgId::BLAKE3,
},
encoding_hash_alg: HashAlgId::BLAKE3,
has_encoding: false,
has_hash: false,
default_kind: TranscriptCommitmentKind::Encoding,
commits: HashSet::default(),
}
}
/// Sets the hash algorithm to use for encoding commitments.
pub fn encoding_hash_alg(&mut self, alg: HashAlgId) -> &mut Self {
self.encoding_hash_alg = alg;
self
}
/// Sets the default kind of commitment to use.
pub fn default_kind(&mut self, default_kind: TranscriptCommitmentKind) -> &mut Self {
self.default_kind = default_kind;
@@ -138,6 +189,11 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
));
}
match kind {
TranscriptCommitmentKind::Encoding => self.has_encoding = true,
TranscriptCommitmentKind::Hash { .. } => self.has_hash = true,
}
self.commits.insert(((direction, idx), kind));
Ok(self)
@@ -184,6 +240,9 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
/// Builds the configuration.
pub fn build(self) -> Result<TranscriptCommitConfig, TranscriptCommitConfigBuilderError> {
Ok(TranscriptCommitConfig {
encoding_hash_alg: self.encoding_hash_alg,
has_encoding: self.has_encoding,
has_hash: self.has_hash,
commits: Vec::from_iter(self.commits),
})
}
@@ -230,10 +289,16 @@ impl fmt::Display for TranscriptCommitConfigBuilderError {
/// Request to compute transcript commitments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptCommitRequest {
encoding: bool,
hash: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
}
impl TranscriptCommitRequest {
/// Returns `true` if an encoding commitment is requested.
pub fn encoding(&self) -> bool {
self.encoding
}
/// Returns `true` if a hash commitment is requested.
pub fn has_hash(&self) -> bool {
!self.hash.is_empty()

View File

@@ -0,0 +1,24 @@
//! Transcript encoding commitments and proofs.
mod encoder;
mod proof;
mod provider;
mod tree;
pub use encoder::{new_encoder, Encoder, EncoderSecret};
pub use proof::{EncodingProof, EncodingProofError};
pub use provider::{EncodingProvider, EncodingProviderError};
pub use tree::{EncodingTree, EncodingTreeError};
use serde::{Deserialize, Serialize};
use crate::hash::TypedHash;
/// Transcript encoding commitment.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EncodingCommitment {
    /// Merkle root of the encoding commitments.
    pub root: TypedHash,
    /// Secret (PRG seed and delta) used to generate the encodings.
    pub secret: EncoderSecret,
}

View File

@@ -0,0 +1,137 @@
use std::ops::Range;
use crate::transcript::Direction;
use itybity::ToBits;
use rand::{RngCore, SeedableRng};
use rand_chacha::ChaCha12Rng;
use serde::{Deserialize, Serialize};
/// The size of the encoding for 1 bit, in bytes.
const BIT_ENCODING_SIZE: usize = 16;
/// The size of the encoding for 1 byte, in bytes.
const BYTE_ENCODING_SIZE: usize = 128;
/// Secret used by an encoder to generate encodings.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EncoderSecret {
    // Seed for the PRG producing the zero-encodings.
    seed: [u8; 32],
    // Delta XORed into a bit's zero-encoding to derive its one-encoding.
    delta: [u8; BIT_ENCODING_SIZE],
}

// Opaque Debug: the secret's contents are never printed.
opaque_debug::implement!(EncoderSecret);
impl EncoderSecret {
    /// Creates a new secret.
    ///
    /// # Arguments
    ///
    /// * `seed` - The seed for the PRG.
    /// * `delta` - Delta for deriving the one-encodings.
    pub fn new(seed: [u8; 32], delta: [u8; 16]) -> Self {
        Self { seed, delta }
    }

    /// Returns the seed for the PRG.
    pub fn seed(&self) -> &[u8; 32] {
        &self.seed
    }

    /// Returns the delta for deriving the one-encodings.
    pub fn delta(&self) -> &[u8; 16] {
        &self.delta
    }
}
/// Creates a new encoder from the given secret.
pub fn new_encoder(secret: &EncoderSecret) -> impl Encoder {
    ChaChaEncoder::new(secret)
}
// ChaCha-based transcript encoder holding copies of the secret's seed and
// delta.
pub(crate) struct ChaChaEncoder {
    seed: [u8; 32],
    delta: [u8; 16],
}
impl ChaChaEncoder {
    /// Creates a new encoder from the secret.
    pub(crate) fn new(secret: &EncoderSecret) -> Self {
        let seed = *secret.seed();
        let delta = *secret.delta();

        Self { seed, delta }
    }

    /// Returns a ChaCha12 PRG seeded with the encoder seed, set to the given
    /// stream, and rewound to word position 0.
    pub(crate) fn new_prg(&self, stream_id: u64) -> ChaCha12Rng {
        let mut prg = ChaCha12Rng::from_seed(self.seed);
        prg.set_stream(stream_id);
        prg.set_word_pos(0);
        prg
    }
}
/// A transcript encoder.
///
/// This is an internal implementation detail that should not be exposed to the
/// public API.
pub trait Encoder {
    /// Writes the zero encoding for the given range of the transcript into the
    /// destination buffer.
    ///
    /// The encoding is appended to `dest`; existing contents are preserved.
    fn encode_range(&self, direction: Direction, range: Range<usize>, dest: &mut Vec<u8>);

    /// Writes the encoding for the given data into the destination buffer.
    ///
    /// `range` locates `data` within the transcript, making the encoding
    /// position-dependent.
    fn encode_data(
        &self,
        direction: Direction,
        range: Range<usize>,
        data: &[u8],
        dest: &mut Vec<u8>,
    );
}
impl Encoder for ChaChaEncoder {
    fn encode_range(&self, direction: Direction, range: Range<usize>, dest: &mut Vec<u8>) {
        // ChaCha encoder works with 32-bit words. Each encoded bit is 128 bits
        // long, so one plaintext byte consumes 8 * 128 / 32 PRG words.
        const WORDS_PER_BYTE: u128 = 8 * 128 / 32;

        // Each direction draws from an independent PRG stream.
        let stream_id: u64 = match direction {
            Direction::Sent => 0,
            Direction::Received => 1,
        };

        let mut prg = self.new_prg(stream_id);
        let len = range.len() * BYTE_ENCODING_SIZE;
        let pos = dest.len();

        // Write 0s to the destination buffer.
        dest.resize(pos + len, 0);

        // Fill the destination buffer with the PRG. Seeking to
        // `range.start * WORDS_PER_BYTE` makes each byte's encoding depend
        // only on its transcript position, not on the requested range.
        prg.set_word_pos(range.start as u128 * WORDS_PER_BYTE);
        prg.fill_bytes(&mut dest[pos..pos + len]);
    }

    fn encode_data(
        &self,
        direction: Direction,
        range: Range<usize>,
        data: &[u8],
        dest: &mut Vec<u8>,
    ) {
        const ZERO: [u8; 16] = [0; BIT_ENCODING_SIZE];

        let pos = dest.len();

        // Write the zero encoding for the given range.
        self.encode_range(direction, range, dest);

        let dest = &mut dest[pos..];

        // Turn the zero-encoding into the data's encoding, iterating bits in
        // LSB0 order.
        for (pos, bit) in data.iter_lsb0().enumerate() {
            // Add the delta to the encoding whenever the encoded bit is 1,
            // otherwise add a zero.
            let summand = if bit { &self.delta } else { &ZERO };
            dest[pos * BIT_ENCODING_SIZE..(pos + 1) * BIT_ENCODING_SIZE]
                .iter_mut()
                .zip(summand)
                .for_each(|(a, b)| *a ^= *b);
        }
    }
}

View File

@@ -0,0 +1,357 @@
use std::{collections::HashMap, fmt};
use rangeset::{RangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use crate::{
hash::{Blinder, HashProvider, HashProviderError},
merkle::{MerkleError, MerkleProof},
transcript::{
commit::MAX_TOTAL_COMMITTED_DATA,
encoding::{new_encoder, Encoder, EncodingCommitment},
Direction,
},
};
/// An opening of a leaf in the encoding tree.
#[derive(Clone, Serialize, Deserialize)]
pub(super) struct Opening {
    /// Direction of the transcript data the leaf commits to.
    pub(super) direction: Direction,
    /// Transcript ranges committed in this leaf.
    pub(super) idx: RangeSet<usize>,
    /// Blinder appended to the leaf preimage before hashing.
    pub(super) blinder: Blinder,
}

// Opaque Debug so the blinder is not printed.
opaque_debug::implement!(Opening);
/// An encoding commitment proof.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(try_from = "validation::EncodingProofUnchecked")]
pub struct EncodingProof {
    /// The proof of inclusion of the commitment(s) in the Merkle tree of
    /// commitments.
    pub(super) inclusion_proof: MerkleProof,
    /// Openings keyed by their leaf id in the Merkle tree.
    pub(super) openings: HashMap<usize, Opening>,
}
impl EncodingProof {
    /// Verifies the proof against the commitment.
    ///
    /// Returns the authenticated indices of the sent and received data,
    /// respectively.
    ///
    /// # Arguments
    ///
    /// * `provider` - Hash provider.
    /// * `commitment` - Encoding commitment to verify against.
    /// * `sent` - Sent data to authenticate.
    /// * `recv` - Received data to authenticate.
    ///
    /// # Errors
    ///
    /// Returns an error if the hash algorithm is unavailable, if the openings
    /// exceed the maximum allowed data or the transcript bounds, or if the
    /// Merkle inclusion proof fails.
    pub fn verify_with_provider(
        &self,
        provider: &HashProvider,
        commitment: &EncodingCommitment,
        sent: &[u8],
        recv: &[u8],
    ) -> Result<(RangeSet<usize>, RangeSet<usize>), EncodingProofError> {
        let hasher = provider.get(&commitment.root.alg)?;

        // Encoder derived from the committed secret, used to recompute the
        // expected leaf encodings from the purported plaintext.
        let encoder = new_encoder(&commitment.secret);

        let Self {
            inclusion_proof,
            openings,
        } = self;

        let mut leaves = Vec::with_capacity(openings.len());
        let mut expected_leaf = Vec::default();
        let mut total_opened = 0u128;
        let mut auth_sent = RangeSet::default();
        let mut auth_recv = RangeSet::default();
        for (
            id,
            Opening {
                direction,
                idx,
                blinder,
            },
        ) in openings
        {
            // Make sure the amount of data being proved is bounded, which
            // limits how much the verifier can be made to hash.
            total_opened += idx.len() as u128;
            if total_opened > MAX_TOTAL_COMMITTED_DATA as u128 {
                // Fix: plain `return Err(...)` — the previous `return Err(...)?`
                // only round-tripped the error through the identity `From` impl.
                return Err(EncodingProofError::new(
                    ErrorKind::Proof,
                    "exceeded maximum allowed data",
                ));
            }

            let (data, auth) = match direction {
                Direction::Sent => (sent, &mut auth_sent),
                Direction::Received => (recv, &mut auth_recv),
            };

            // Make sure the ranges are within the bounds of the transcript.
            if idx.end().unwrap_or(0) > data.len() {
                return Err(EncodingProofError::new(
                    ErrorKind::Proof,
                    format!(
                        "index out of bounds of the transcript ({}): {} > {}",
                        direction,
                        idx.end().unwrap_or(0),
                        data.len()
                    ),
                ));
            }

            // Reconstruct the leaf preimage: the encodings of the purported
            // data followed by the blinder.
            expected_leaf.clear();
            for range in idx.iter_ranges() {
                encoder.encode_data(*direction, range.clone(), &data[range], &mut expected_leaf);
            }
            expected_leaf.extend_from_slice(blinder.as_bytes());

            // Compute the expected hash of the commitment to make sure it is
            // present in the merkle tree.
            leaves.push((*id, hasher.hash(&expected_leaf)));

            auth.union_mut(idx);
        }

        // Verify that the expected hashes are present in the merkle tree.
        //
        // This proves the Prover committed to the purported data prior to the
        // encoder seed being revealed. Ergo, if the encodings are authentic
        // then the purported data is authentic.
        inclusion_proof.verify(hasher, &commitment.root, leaves)?;

        Ok((auth_sent, auth_recv))
    }
}
/// Error for [`EncodingProof`].
#[derive(Debug, thiserror::Error)]
pub struct EncodingProofError {
    /// Broad category of the failure.
    kind: ErrorKind,
    /// Underlying cause, if any.
    source: Option<Box<dyn std::error::Error + Send + Sync>>,
}
impl EncodingProofError {
    /// Creates a new error of the given kind wrapping the provided source.
    fn new<E>(kind: ErrorKind, source: E) -> Self
    where
        E: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        let source = Some(source.into());
        Self { kind, source }
    }
}
/// Category of an [`EncodingProofError`].
#[derive(Debug)]
enum ErrorKind {
    /// Error originating from the hash provider.
    Provider,
    /// Error in the proof itself (bounds, size limits, Merkle verification).
    Proof,
}
impl fmt::Display for EncodingProofError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("encoding proof error: ")?;
match self.kind {
ErrorKind::Provider => f.write_str("provider error")?,
ErrorKind::Proof => f.write_str("proof error")?,
}
if let Some(source) = &self.source {
write!(f, " caused by: {source}")?;
}
Ok(())
}
}
impl From<HashProviderError> for EncodingProofError {
fn from(error: HashProviderError) -> Self {
Self::new(ErrorKind::Provider, error)
}
}
impl From<MerkleError> for EncodingProofError {
fn from(error: MerkleError) -> Self {
Self::new(ErrorKind::Proof, error)
}
}
/// Invalid encoding proof error.
///
/// Produced during deserialization when an [`EncodingProof`] fails the
/// structural validation in the `validation` module (see the struct's
/// `#[serde(try_from = ...)]` attribute).
#[derive(Debug, thiserror::Error)]
#[error("invalid encoding proof: {0}")]
pub struct InvalidEncodingProof(&'static str);
mod validation {
    use super::*;

    /// The maximum allowed height of the Merkle tree of encoding commitments.
    ///
    /// The statistical security parameter (SSP) of the encoding commitment
    /// protocol is "the number of uniformly random bits in a single bit's
    /// encoding minus `MAX_HEIGHT`".
    ///
    /// For example, a bit encoding used in garbled circuits typically has 127
    /// uniformly random bits, so when used in the encoding commitment
    /// protocol the SSP is 127 - 30 = 97 bits.
    ///
    /// Leaving this validation here as a fail-safe in case we ever start
    /// using shorter encodings.
    const MAX_HEIGHT: usize = 30;

    /// An [`EncodingProof`] whose structural invariants are not yet checked.
    #[derive(Debug, Deserialize)]
    pub(super) struct EncodingProofUnchecked {
        inclusion_proof: MerkleProof,
        openings: HashMap<usize, Opening>,
    }

    impl TryFrom<EncodingProofUnchecked> for EncodingProof {
        type Error = InvalidEncodingProof;

        fn try_from(unchecked: EncodingProofUnchecked) -> Result<Self, Self::Error> {
            let EncodingProofUnchecked {
                inclusion_proof,
                openings,
            } = unchecked;

            // Reject trees taller than the maximum allowed height.
            if inclusion_proof.leaf_count() > 1 << MAX_HEIGHT {
                return Err(InvalidEncodingProof(
                    "the height of the tree exceeds the maximum allowed",
                ));
            }

            Ok(Self {
                inclusion_proof,
                openings,
            })
        }
    }
}
#[cfg(test)]
mod test {
    use tlsn_data_fixtures::http::{request::POST_JSON, response::OK_JSON};

    use crate::{
        fixtures::{encoder_secret, encoder_secret_tampered_seed, encoding_provider},
        hash::Blake3,
        transcript::{
            encoding::{EncoderSecret, EncodingTree},
            Transcript,
        },
    };

    use super::*;

    /// A transcript together with a proof and commitment over its full
    /// contents, shared by the tests below.
    struct EncodingFixture {
        transcript: Transcript,
        proof: EncodingProof,
        commitment: EncodingCommitment,
    }

    /// Builds a fixture committing to the entire sent and received transcript
    /// using the given encoder secret.
    fn new_encoding_fixture(secret: EncoderSecret) -> EncodingFixture {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
        let idx_1 = (Direction::Received, RangeSet::from(0..OK_JSON.len()));

        let provider = encoding_provider(transcript.sent(), transcript.received());

        let tree = EncodingTree::new(&Blake3::default(), [&idx_0, &idx_1], &provider).unwrap();

        let proof = tree.proof([&idx_0, &idx_1].into_iter()).unwrap();

        let commitment = EncodingCommitment {
            root: tree.root(),
            secret,
        };

        EncodingFixture {
            transcript,
            proof,
            commitment,
        }
    }

    // Verification must fail when the commitment's encoder seed does not
    // match the seed that produced the encodings.
    #[test]
    fn test_verify_encoding_proof_tampered_seed() {
        let EncodingFixture {
            transcript,
            proof,
            commitment,
        } = new_encoding_fixture(encoder_secret_tampered_seed());

        let err = proof
            .verify_with_provider(
                &HashProvider::default(),
                &commitment,
                transcript.sent(),
                transcript.received(),
            )
            .unwrap_err();

        assert!(matches!(err.kind, ErrorKind::Proof));
    }

    // Verification must fail when the transcript slices are shorter than the
    // ranges the proof opens.
    #[test]
    fn test_verify_encoding_proof_out_of_range() {
        let EncodingFixture {
            transcript,
            proof,
            commitment,
        } = new_encoding_fixture(encoder_secret());

        // Truncated views of the transcript data.
        let sent = &transcript.sent()[transcript.sent().len() - 1..];
        let recv = &transcript.received()[transcript.received().len() - 2..];

        let err = proof
            .verify_with_provider(&HashProvider::default(), &commitment, sent, recv)
            .unwrap_err();

        assert!(matches!(err.kind, ErrorKind::Proof));
    }

    // Verification must fail when an opening's index is altered after proof
    // generation (leaf hash no longer matches).
    #[test]
    fn test_verify_encoding_proof_tampered_idx() {
        let EncodingFixture {
            transcript,
            mut proof,
            commitment,
        } = new_encoding_fixture(encoder_secret());

        let Opening { idx, .. } = proof.openings.values_mut().next().unwrap();

        *idx = RangeSet::from([0..3, 13..15]);

        let err = proof
            .verify_with_provider(
                &HashProvider::default(),
                &commitment,
                transcript.sent(),
                transcript.received(),
            )
            .unwrap_err();

        assert!(matches!(err.kind, ErrorKind::Proof));
    }

    // Verification must fail when an opening's blinder is replaced (leaf hash
    // no longer matches).
    #[test]
    fn test_verify_encoding_proof_tampered_encoding_blinder() {
        let EncodingFixture {
            transcript,
            mut proof,
            commitment,
        } = new_encoding_fixture(encoder_secret());

        let Opening { blinder, .. } = proof.openings.values_mut().next().unwrap();

        *blinder = rand::random();

        let err = proof
            .verify_with_provider(
                &HashProvider::default(),
                &commitment,
                transcript.sent(),
                transcript.received(),
            )
            .unwrap_err();

        assert!(matches!(err.kind, ErrorKind::Proof));
    }
}

View File

@@ -0,0 +1,19 @@
use std::ops::Range;
use crate::transcript::Direction;
/// A provider of plaintext encodings.
pub trait EncodingProvider {
    /// Writes the encoding of the given range into the destination buffer.
    ///
    /// # Arguments
    ///
    /// * `direction` - Direction of the transcript data to encode.
    /// * `range` - Byte range of the transcript to encode.
    /// * `dest` - Destination buffer; callers invoke this repeatedly with the
    ///   same buffer, so the encoding is appended rather than overwriting.
    fn provide_encoding(
        &self,
        direction: Direction,
        range: Range<usize>,
        dest: &mut Vec<u8>,
    ) -> Result<(), EncodingProviderError>;
}
/// Error for [`EncodingProvider`].
///
/// Opaque: carries no detail about which range or direction failed.
#[derive(Debug, thiserror::Error)]
#[error("failed to provide encoding")]
pub struct EncodingProviderError;

View File

@@ -0,0 +1,331 @@
use std::collections::HashMap;
use bimap::BiMap;
use rangeset::{RangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use crate::{
hash::{Blinder, HashAlgId, HashAlgorithm, TypedHash},
merkle::MerkleTree,
transcript::{
encoding::{
proof::{EncodingProof, Opening},
EncodingProvider,
},
Direction,
},
};
/// Encoding tree builder error.
#[derive(Debug, thiserror::Error)]
pub enum EncodingTreeError {
    /// Index is out of bounds of the transcript.
    // NOTE(review): not constructed within this module — presumably returned
    // by callers validating indices before building the tree; confirm.
    #[error("index is out of bounds of the transcript")]
    OutOfBounds {
        /// The index.
        index: RangeSet<usize>,
        /// The transcript length.
        transcript_length: usize,
    },
    /// Encoding provider is missing an encoding for an index.
    ///
    /// Returned by [`EncodingTree::new`] when the provider fails for a range.
    #[error("encoding provider is missing an encoding for an index")]
    MissingEncoding {
        /// The index which is missing.
        index: RangeSet<usize>,
    },
    /// Index is missing from the tree.
    ///
    /// Returned by [`EncodingTree::proof`] for an uncommitted index.
    #[error("index is missing from the tree")]
    MissingLeaf {
        /// The index which is missing.
        index: RangeSet<usize>,
    },
}
/// A merkle tree of transcript encodings.
#[derive(Clone, Serialize, Deserialize)]
pub struct EncodingTree {
    /// Merkle tree of the commitments.
    tree: MerkleTree,
    /// Nonces used to blind the hashes.
    ///
    /// Position-aligned with leaf indices: `blinders[i]` blinds leaf `i`.
    blinders: Vec<Blinder>,
    /// Mapping between the index of a leaf and the transcript index it
    /// corresponds to.
    idxs: BiMap<usize, (Direction, RangeSet<usize>)>,
    /// Union of all transcript indices in the sent direction.
    sent_idx: RangeSet<usize>,
    /// Union of all transcript indices in the received direction.
    received_idx: RangeSet<usize>,
}

opaque_debug::implement!(EncodingTree);
impl EncodingTree {
    /// Creates a new encoding tree.
    ///
    /// Empty and duplicate indices are silently skipped.
    ///
    /// # Arguments
    ///
    /// * `hasher` - The hash algorithm to use.
    /// * `idxs` - The subsequence indices to commit to.
    /// * `provider` - The encoding provider.
    ///
    /// # Errors
    ///
    /// Returns [`EncodingTreeError::MissingEncoding`] if the provider cannot
    /// supply an encoding for any requested range.
    pub fn new<'idx>(
        hasher: &dyn HashAlgorithm,
        idxs: impl IntoIterator<Item = &'idx (Direction, RangeSet<usize>)>,
        provider: &dyn EncodingProvider,
    ) -> Result<Self, EncodingTreeError> {
        let mut this = Self {
            tree: MerkleTree::new(hasher.id()),
            blinders: Vec::new(),
            idxs: BiMap::new(),
            sent_idx: RangeSet::default(),
            received_idx: RangeSet::default(),
        };

        let mut leaves = Vec::new();
        // Scratch buffer reused across leaves to avoid reallocating.
        let mut encoding = Vec::new();
        for dir_idx in idxs {
            let direction = dir_idx.0;
            let idx = &dir_idx.1;

            // Ignore empty indices.
            if idx.is_empty() {
                continue;
            }

            if this.idxs.contains_right(dir_idx) {
                // The subsequence is already in the tree.
                continue;
            }

            let blinder: Blinder = rand::random();

            encoding.clear();
            for range in idx.iter_ranges() {
                provider
                    .provide_encoding(direction, range, &mut encoding)
                    .map_err(|_| EncodingTreeError::MissingEncoding { index: idx.clone() })?;
            }
            // Append the blinder so the leaf hash hides the encoding.
            encoding.extend_from_slice(blinder.as_bytes());

            let leaf = hasher.hash(&encoding);

            leaves.push(leaf);
            this.blinders.push(blinder);
            // Leaf indices are assigned in insertion order, keeping
            // `blinders` position-aligned with the leaves.
            this.idxs.insert(this.idxs.len(), dir_idx.clone());
            match direction {
                Direction::Sent => this.sent_idx.union_mut(idx),
                Direction::Received => this.received_idx.union_mut(idx),
            }
        }

        this.tree.insert(hasher, leaves);

        Ok(this)
    }

    /// Returns the root of the tree.
    pub fn root(&self) -> TypedHash {
        self.tree.root()
    }

    /// Returns the hash algorithm of the tree.
    pub fn algorithm(&self) -> HashAlgId {
        self.tree.algorithm()
    }

    /// Generates a proof for the given indices.
    ///
    /// # Arguments
    ///
    /// * `idxs` - The transcript indices to prove.
    ///
    /// # Errors
    ///
    /// Returns [`EncodingTreeError::MissingLeaf`] if any index was not
    /// committed to in this tree.
    pub fn proof<'idx>(
        &self,
        idxs: impl Iterator<Item = &'idx (Direction, RangeSet<usize>)>,
    ) -> Result<EncodingProof, EncodingTreeError> {
        let mut openings = HashMap::new();
        for dir_idx in idxs {
            let direction = dir_idx.0;
            let idx = &dir_idx.1;

            let leaf_idx = *self
                .idxs
                .get_by_right(dir_idx)
                .ok_or_else(|| EncodingTreeError::MissingLeaf { index: idx.clone() })?;

            // Indexing is safe: `blinders` is position-aligned with the leaf
            // indices stored in `idxs`.
            let blinder = self.blinders[leaf_idx].clone();

            openings.insert(
                leaf_idx,
                Opening {
                    direction,
                    idx: idx.clone(),
                    blinder,
                },
            );
        }

        let mut indices = openings.keys().copied().collect::<Vec<_>>();
        // Unstable sort is sufficient (and faster) for primitive keys
        // (clippy::stable_sort_primitive).
        indices.sort_unstable();

        Ok(EncodingProof {
            inclusion_proof: self.tree.proof(&indices),
            openings,
        })
    }

    /// Returns whether the tree contains the given transcript index.
    pub fn contains(&self, idx: &(Direction, RangeSet<usize>)) -> bool {
        self.idxs.contains_right(idx)
    }

    /// Returns the union of all committed indices in the given direction.
    pub(crate) fn idx(&self, direction: Direction) -> &RangeSet<usize> {
        match direction {
            Direction::Sent => &self.sent_idx,
            Direction::Received => &self.received_idx,
        }
    }

    /// Returns the committed transcript indices.
    pub(crate) fn transcript_indices(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
        self.idxs.right_values()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        fixtures::{encoder_secret, encoding_provider},
        hash::{Blake3, HashProvider},
        transcript::{encoding::EncodingCommitment, Transcript},
    };

    use tlsn_data_fixtures::http::{request::POST_JSON, response::OK_JSON};

    /// Builds a Blake3 encoding tree over the given transcript indices using
    /// the fixture encoding provider.
    fn new_tree<'seq>(
        transcript: &Transcript,
        idxs: impl Iterator<Item = &'seq (Direction, RangeSet<usize>)>,
    ) -> Result<EncodingTree, EncodingTreeError> {
        let provider = encoding_provider(transcript.sent(), transcript.received());

        EncodingTree::new(&Blake3::default(), idxs, &provider)
    }

    // Committing to the whole transcript and proving both indices should
    // authenticate exactly the committed ranges.
    #[test]
    fn test_encoding_tree() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
        let idx_1 = (Direction::Received, RangeSet::from(0..OK_JSON.len()));

        let tree = new_tree(&transcript, [&idx_0, &idx_1].into_iter()).unwrap();

        assert!(tree.contains(&idx_0));
        assert!(tree.contains(&idx_1));

        let proof = tree.proof([&idx_0, &idx_1].into_iter()).unwrap();

        let commitment = EncodingCommitment {
            root: tree.root(),
            secret: encoder_secret(),
        };

        let (auth_sent, auth_recv) = proof
            .verify_with_provider(
                &HashProvider::default(),
                &commitment,
                transcript.sent(),
                transcript.received(),
            )
            .unwrap();

        assert_eq!(auth_sent, idx_0.1);
        assert_eq!(auth_recv, idx_1.1);
    }

    // Multiple disjoint commitments per direction should verify and the
    // authenticated ranges should be the union of the committed ones.
    #[test]
    fn test_encoding_tree_multiple_ranges() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, RangeSet::from(0..1));
        let idx_1 = (Direction::Sent, RangeSet::from(1..POST_JSON.len()));
        let idx_2 = (Direction::Received, RangeSet::from(0..1));
        let idx_3 = (Direction::Received, RangeSet::from(1..OK_JSON.len()));

        let tree = new_tree(&transcript, [&idx_0, &idx_1, &idx_2, &idx_3].into_iter()).unwrap();

        assert!(tree.contains(&idx_0));
        assert!(tree.contains(&idx_1));
        assert!(tree.contains(&idx_2));
        assert!(tree.contains(&idx_3));

        let proof = tree
            .proof([&idx_0, &idx_1, &idx_2, &idx_3].into_iter())
            .unwrap();

        let commitment = EncodingCommitment {
            root: tree.root(),
            secret: encoder_secret(),
        };

        let (auth_sent, auth_recv) = proof
            .verify_with_provider(
                &HashProvider::default(),
                &commitment,
                transcript.sent(),
                transcript.received(),
            )
            .unwrap();

        let mut expected_auth_sent = RangeSet::default();
        expected_auth_sent.union_mut(&idx_0.1);
        expected_auth_sent.union_mut(&idx_1.1);

        let mut expected_auth_recv = RangeSet::default();
        expected_auth_recv.union_mut(&idx_2.1);
        expected_auth_recv.union_mut(&idx_3.1);

        assert_eq!(auth_sent, expected_auth_sent);
        assert_eq!(auth_recv, expected_auth_recv);
    }

    // Requesting a proof for an index that was never committed must fail.
    #[test]
    fn test_encoding_tree_proof_missing_leaf() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
        let idx_1 = (Direction::Received, RangeSet::from(0..4));
        let idx_2 = (Direction::Received, RangeSet::from(4..OK_JSON.len()));

        let tree = new_tree(&transcript, [&idx_0, &idx_1].into_iter()).unwrap();

        let result = tree
            .proof([&idx_0, &idx_1, &idx_2].into_iter())
            .unwrap_err();
        assert!(matches!(result, EncodingTreeError::MissingLeaf { .. }));
    }

    // Indices extending past the transcript surface as MissingEncoding,
    // since the provider has no encoding for out-of-range bytes.
    #[test]
    fn test_encoding_tree_out_of_bounds() {
        let transcript = Transcript::new(POST_JSON, OK_JSON);

        let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len() + 1));
        let idx_1 = (Direction::Received, RangeSet::from(0..OK_JSON.len() + 1));

        let result = new_tree(&transcript, [&idx_0].into_iter()).unwrap_err();
        assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));

        let result = new_tree(&transcript, [&idx_1].into_iter()).unwrap_err();
        assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));
    }

    // An empty provider cannot supply any encodings at all.
    #[test]
    fn test_encoding_tree_missing_encoding() {
        let provider = encoding_provider(&[], &[]);

        let result = EncodingTree::new(
            &Blake3::default(),
            [(Direction::Sent, RangeSet::from(0..8))].iter(),
            &provider,
        )
        .unwrap_err();
        assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));
    }
}

View File

@@ -1,10 +1,6 @@
//! Transcript proofs.
use rangeset::{
iter::RangeIterator,
ops::{Cover, Set},
set::ToRangeSet,
};
use rangeset::{Cover, Difference, Subset, ToRangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, fmt};
@@ -14,6 +10,7 @@ use crate::{
hash::{HashAlgId, HashProvider},
transcript::{
commit::{TranscriptCommitment, TranscriptCommitmentKind},
encoding::{EncodingProof, EncodingProofError, EncodingTree},
hash::{hash_plaintext, PlaintextHash, PlaintextHashSecret},
Direction, PartialTranscript, RangeSet, Transcript, TranscriptSecret,
},
@@ -25,18 +22,14 @@ const DEFAULT_COMMITMENT_KINDS: &[TranscriptCommitmentKind] = &[
TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256,
},
TranscriptCommitmentKind::Hash {
alg: HashAlgId::BLAKE3,
},
TranscriptCommitmentKind::Hash {
alg: HashAlgId::KECCAK256,
},
TranscriptCommitmentKind::Encoding,
];
/// Proof of the contents of a transcript.
#[derive(Clone, Serialize, Deserialize)]
pub struct TranscriptProof {
transcript: PartialTranscript,
encoding_proof: Option<EncodingProof>,
hash_secrets: Vec<PlaintextHashSecret>,
}
@@ -50,18 +43,26 @@ impl TranscriptProof {
/// # Arguments
///
/// * `provider` - The hash provider to use for verification.
/// * `length` - The transcript length.
/// * `commitments` - The commitments to verify against.
/// * `attestation_body` - The attestation body to verify against.
pub fn verify_with_provider<'a>(
self,
provider: &HashProvider,
length: &TranscriptLength,
commitments: impl IntoIterator<Item = &'a TranscriptCommitment>,
) -> Result<PartialTranscript, TranscriptProofError> {
let mut encoding_commitment = None;
let mut hash_commitments = HashSet::new();
// Index commitments.
for commitment in commitments {
match commitment {
TranscriptCommitment::Encoding(commitment) => {
if encoding_commitment.replace(commitment).is_some() {
return Err(TranscriptProofError::new(
ErrorKind::Encoding,
"multiple encoding commitments are present.",
));
}
}
TranscriptCommitment::Hash(plaintext_hash) => {
hash_commitments.insert(plaintext_hash);
}
@@ -80,6 +81,26 @@ impl TranscriptProof {
let mut total_auth_sent = RangeSet::default();
let mut total_auth_recv = RangeSet::default();
// Verify encoding proof.
if let Some(proof) = self.encoding_proof {
let commitment = encoding_commitment.ok_or_else(|| {
TranscriptProofError::new(
ErrorKind::Encoding,
"contains an encoding proof but missing encoding commitment",
)
})?;
let (auth_sent, auth_recv) = proof.verify_with_provider(
provider,
commitment,
self.transcript.sent_unsafe(),
self.transcript.received_unsafe(),
)?;
total_auth_sent.union_mut(&auth_sent);
total_auth_recv.union_mut(&auth_recv);
}
let mut buffer = Vec::new();
for PlaintextHashSecret {
direction,
@@ -108,7 +129,7 @@ impl TranscriptProof {
}
buffer.clear();
for range in idx.iter() {
for range in idx.iter_ranges() {
buffer.extend_from_slice(&plaintext[range]);
}
@@ -163,6 +184,7 @@ impl TranscriptProofError {
#[derive(Debug)]
enum ErrorKind {
Encoding,
Hash,
Proof,
}
@@ -172,6 +194,7 @@ impl fmt::Display for TranscriptProofError {
f.write_str("transcript proof error: ")?;
match self.kind {
ErrorKind::Encoding => f.write_str("encoding error")?,
ErrorKind::Hash => f.write_str("hash error")?,
ErrorKind::Proof => f.write_str("proof error")?,
}
@@ -184,6 +207,12 @@ impl fmt::Display for TranscriptProofError {
}
}
impl From<EncodingProofError> for TranscriptProofError {
fn from(e: EncodingProofError) -> Self {
TranscriptProofError::new(ErrorKind::Encoding, e)
}
}
/// Union of ranges to reveal.
#[derive(Clone, Debug, PartialEq)]
struct QueryIdx {
@@ -228,6 +257,7 @@ pub struct TranscriptProofBuilder<'a> {
/// Commitment kinds in order of preference for building transcript proofs.
commitment_kinds: Vec<TranscriptCommitmentKind>,
transcript: &'a Transcript,
encoding_tree: Option<&'a EncodingTree>,
hash_secrets: Vec<&'a PlaintextHashSecret>,
committed_sent: RangeSet<usize>,
committed_recv: RangeSet<usize>,
@@ -243,9 +273,15 @@ impl<'a> TranscriptProofBuilder<'a> {
let mut committed_sent = RangeSet::default();
let mut committed_recv = RangeSet::default();
let mut encoding_tree = None;
let mut hash_secrets = Vec::new();
for secret in secrets {
match secret {
TranscriptSecret::Encoding(tree) => {
committed_sent.union_mut(tree.idx(Direction::Sent));
committed_recv.union_mut(tree.idx(Direction::Received));
encoding_tree = Some(tree);
}
TranscriptSecret::Hash(hash) => {
match hash.direction {
Direction::Sent => committed_sent.union_mut(&hash.idx),
@@ -259,6 +295,7 @@ impl<'a> TranscriptProofBuilder<'a> {
Self {
commitment_kinds: DEFAULT_COMMITMENT_KINDS.to_vec(),
transcript,
encoding_tree,
hash_secrets,
committed_sent,
committed_recv,
@@ -314,7 +351,7 @@ impl<'a> TranscriptProofBuilder<'a> {
if idx.is_subset(committed) {
self.query_idx.union(&direction, &idx);
} else {
let missing = idx.difference(committed).into_set();
let missing = idx.difference(committed);
return Err(TranscriptProofBuilderError::new(
BuilderErrorKind::MissingCommitment,
format!(
@@ -356,6 +393,7 @@ impl<'a> TranscriptProofBuilder<'a> {
transcript: self
.transcript
.to_partial(self.query_idx.sent.clone(), self.query_idx.recv.clone()),
encoding_proof: None,
hash_secrets: Vec::new(),
};
let mut uncovered_query_idx = self.query_idx.clone();
@@ -367,6 +405,46 @@ impl<'a> TranscriptProofBuilder<'a> {
// self.commitment_kinds.
if let Some(kind) = commitment_kinds_iter.next() {
match kind {
TranscriptCommitmentKind::Encoding => {
let Some(encoding_tree) = self.encoding_tree else {
// Proceeds to the next preferred commitment kind if encoding tree is
// not available.
continue;
};
let (sent_dir_idxs, sent_uncovered) = uncovered_query_idx.sent.cover_by(
encoding_tree
.transcript_indices()
.filter(|(dir, _)| *dir == Direction::Sent),
|(_, idx)| idx,
);
// Uncovered ranges will be checked with ranges of the next
// preferred commitment kind.
uncovered_query_idx.sent = sent_uncovered;
let (recv_dir_idxs, recv_uncovered) = uncovered_query_idx.recv.cover_by(
encoding_tree
.transcript_indices()
.filter(|(dir, _)| *dir == Direction::Received),
|(_, idx)| idx,
);
uncovered_query_idx.recv = recv_uncovered;
let dir_idxs = sent_dir_idxs
.into_iter()
.chain(recv_dir_idxs)
.collect::<Vec<_>>();
// Skip proof generation if there are no committed ranges that can cover the
// query ranges.
if !dir_idxs.is_empty() {
transcript_proof.encoding_proof = Some(
encoding_tree
.proof(dir_idxs.into_iter())
.expect("subsequences were checked to be in tree"),
);
}
}
TranscriptCommitmentKind::Hash { alg } => {
let (sent_hashes, sent_uncovered) = uncovered_query_idx.sent.cover_by(
self.hash_secrets.iter().filter(|hash| {
@@ -489,14 +567,45 @@ impl fmt::Display for TranscriptProofBuilderError {
#[cfg(test)]
mod tests {
use rand::{Rng, SeedableRng};
use rangeset::prelude::*;
use rangeset::RangeSet;
use rstest::rstest;
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
use crate::hash::{Blinder, HashAlgId};
use crate::{
fixtures::encoding_provider,
hash::{Blake3, Blinder, HashAlgId},
transcript::TranscriptCommitConfigBuilder,
};
use super::*;
#[rstest]
fn test_verify_missing_encoding_commitment_root() {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let idxs = vec![(Direction::Received, RangeSet::from(0..transcript.len().1))];
let encoding_tree = EncodingTree::new(
&Blake3::default(),
&idxs,
&encoding_provider(transcript.sent(), transcript.received()),
)
.unwrap();
let secrets = vec![TranscriptSecret::Encoding(encoding_tree)];
let mut builder = TranscriptProofBuilder::new(&transcript, &secrets);
builder.reveal_recv(&(0..transcript.len().1)).unwrap();
let transcript_proof = builder.build().unwrap();
let provider = HashProvider::default();
let err = transcript_proof
.verify_with_provider(&provider, &transcript.length(), &[])
.err()
.unwrap();
assert!(matches!(err.kind, ErrorKind::Encoding));
}
#[rstest]
fn test_reveal_range_out_of_bounds() {
let transcript = Transcript::new(
@@ -516,7 +625,7 @@ mod tests {
}
#[rstest]
fn test_reveal_missing_commitment() {
fn test_reveal_missing_encoding_tree() {
let transcript = Transcript::new(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
@@ -528,10 +637,7 @@ mod tests {
}
#[rstest]
#[case::sha256(HashAlgId::SHA256)]
#[case::blake3(HashAlgId::BLAKE3)]
#[case::keccak256(HashAlgId::KECCAK256)]
fn test_reveal_with_hash_commitment(#[case] alg: HashAlgId) {
fn test_reveal_with_hash_commitment() {
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
let provider = HashProvider::default();
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
@@ -539,6 +645,7 @@ mod tests {
let direction = Direction::Sent;
let idx = RangeSet::from(0..10);
let blinder: Blinder = rng.random();
let alg = HashAlgId::SHA256;
let hasher = provider.get(&alg).unwrap();
let commitment = PlaintextHash {
@@ -576,10 +683,7 @@ mod tests {
}
#[rstest]
#[case::sha256(HashAlgId::SHA256)]
#[case::blake3(HashAlgId::BLAKE3)]
#[case::keccak256(HashAlgId::KECCAK256)]
fn test_reveal_with_inconsistent_hash_commitment(#[case] alg: HashAlgId) {
fn test_reveal_with_inconsistent_hash_commitment() {
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
let provider = HashProvider::default();
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
@@ -587,6 +691,7 @@ mod tests {
let direction = Direction::Sent;
let idx = RangeSet::from(0..10);
let blinder: Blinder = rng.random();
let alg = HashAlgId::SHA256;
let hasher = provider.get(&alg).unwrap();
let commitment = PlaintextHash {
@@ -629,19 +734,24 @@ mod tests {
TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256,
},
TranscriptCommitmentKind::Encoding,
TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256,
},
TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256,
},
TranscriptCommitmentKind::Encoding,
]);
assert_eq!(
builder.commitment_kinds,
vec![TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256
},]
vec![
TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256
},
TranscriptCommitmentKind::Encoding
]
);
}
@@ -651,7 +761,7 @@ mod tests {
RangeSet::from([0..10, 12..30]),
true,
)]
#[case::reveal_all_rangesets_with_single_superset_range(
#[case::reveal_all_rangesets_with_superset_ranges(
vec![RangeSet::from([0..1]), RangeSet::from([1..2, 8..9]), RangeSet::from([2..4, 6..8]), RangeSet::from([2..3, 6..7]), RangeSet::from([9..12])],
RangeSet::from([0..4, 6..9]),
true,
@@ -682,30 +792,29 @@ mod tests {
false,
)]
#[allow(clippy::single_range_in_vec_init)]
fn test_reveal_multiple_rangesets_with_one_rangeset(
fn test_reveal_mutliple_rangesets_with_one_rangeset(
#[case] commit_recv_rangesets: Vec<RangeSet<usize>>,
#[case] reveal_recv_rangeset: RangeSet<usize>,
#[case] success: bool,
) {
use rand::{Rng, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
// Create hash commitments for each rangeset
let mut secrets = Vec::new();
// Encoding commitment kind
let mut transcript_commitment_builder = TranscriptCommitConfigBuilder::new(&transcript);
for rangeset in commit_recv_rangesets.iter() {
let blinder: crate::hash::Blinder = rng.random();
let secret = PlaintextHashSecret {
direction: Direction::Received,
idx: rangeset.clone(),
alg: HashAlgId::BLAKE3,
blinder,
};
secrets.push(TranscriptSecret::Hash(secret));
transcript_commitment_builder.commit_recv(rangeset).unwrap();
}
let transcripts_commitment_config = transcript_commitment_builder.build().unwrap();
let encoding_tree = EncodingTree::new(
&Blake3::default(),
transcripts_commitment_config.iter_encoding(),
&encoding_provider(GET_WITH_HEADER, OK_JSON),
)
.unwrap();
let secrets = vec![TranscriptSecret::Encoding(encoding_tree)];
let mut builder = TranscriptProofBuilder::new(&transcript, &secrets);
if success {
@@ -758,34 +867,27 @@ mod tests {
#[case] uncovered_sent_rangeset: RangeSet<usize>,
#[case] uncovered_recv_rangeset: RangeSet<usize>,
) {
use rand::{Rng, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
// Create hash commitments for each rangeset
let mut secrets = Vec::new();
// Encoding commitment kind
let mut transcript_commitment_builder = TranscriptCommitConfigBuilder::new(&transcript);
for rangeset in commit_sent_rangesets.iter() {
let blinder: crate::hash::Blinder = rng.random();
let secret = PlaintextHashSecret {
direction: Direction::Sent,
idx: rangeset.clone(),
alg: HashAlgId::BLAKE3,
blinder,
};
secrets.push(TranscriptSecret::Hash(secret));
transcript_commitment_builder.commit_sent(rangeset).unwrap();
}
for rangeset in commit_recv_rangesets.iter() {
let blinder: crate::hash::Blinder = rng.random();
let secret = PlaintextHashSecret {
direction: Direction::Received,
idx: rangeset.clone(),
alg: HashAlgId::BLAKE3,
blinder,
};
secrets.push(TranscriptSecret::Hash(secret));
transcript_commitment_builder.commit_recv(rangeset).unwrap();
}
let transcripts_commitment_config = transcript_commitment_builder.build().unwrap();
let encoding_tree = EncodingTree::new(
&Blake3::default(),
transcripts_commitment_config.iter_encoding(),
&encoding_provider(GET_WITH_HEADER, OK_JSON),
)
.unwrap();
let secrets = vec![TranscriptSecret::Encoding(encoding_tree)];
let mut builder = TranscriptProofBuilder::new(&transcript, &secrets);
builder.reveal_sent(&reveal_sent_rangeset).unwrap();
builder.reveal_recv(&reveal_recv_rangeset).unwrap();

View File

@@ -10,53 +10,10 @@ use crate::{
use tls_core::msgs::{
alert::AlertMessagePayload,
codec::{Codec, Reader},
enums::{AlertDescription, ProtocolVersion},
enums::{AlertDescription, ContentType, ProtocolVersion},
handshake::{HandshakeMessagePayload, HandshakePayload},
};
/// TLS record content type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ContentType {
/// Change cipher spec protocol.
ChangeCipherSpec,
/// Alert protocol.
Alert,
/// Handshake protocol.
Handshake,
/// Application data protocol.
ApplicationData,
/// Heartbeat protocol.
Heartbeat,
/// Unknown protocol.
Unknown(u8),
}
impl From<ContentType> for tls_core::msgs::enums::ContentType {
fn from(content_type: ContentType) -> Self {
match content_type {
ContentType::ChangeCipherSpec => tls_core::msgs::enums::ContentType::ChangeCipherSpec,
ContentType::Alert => tls_core::msgs::enums::ContentType::Alert,
ContentType::Handshake => tls_core::msgs::enums::ContentType::Handshake,
ContentType::ApplicationData => tls_core::msgs::enums::ContentType::ApplicationData,
ContentType::Heartbeat => tls_core::msgs::enums::ContentType::Heartbeat,
ContentType::Unknown(id) => tls_core::msgs::enums::ContentType::Unknown(id),
}
}
}
impl From<tls_core::msgs::enums::ContentType> for ContentType {
fn from(content_type: tls_core::msgs::enums::ContentType) -> Self {
match content_type {
tls_core::msgs::enums::ContentType::ChangeCipherSpec => ContentType::ChangeCipherSpec,
tls_core::msgs::enums::ContentType::Alert => ContentType::Alert,
tls_core::msgs::enums::ContentType::Handshake => ContentType::Handshake,
tls_core::msgs::enums::ContentType::ApplicationData => ContentType::ApplicationData,
tls_core::msgs::enums::ContentType::Heartbeat => ContentType::Heartbeat,
tls_core::msgs::enums::ContentType::Unknown(id) => ContentType::Unknown(id),
}
}
}
/// A transcript of TLS records sent and received by the prover.
#[derive(Debug, Clone)]
pub struct TlsTranscript {

View File

@@ -53,21 +53,6 @@ impl RootCertStore {
pub fn empty() -> Self {
Self { roots: Vec::new() }
}
/// Creates a root certificate store with Mozilla root certificates.
///
/// These certificates are sourced from [`webpki-root-certs`](https://docs.rs/webpki-root-certs/latest/webpki_root_certs/). It is not recommended to use these unless the
/// application binary can be recompiled and deployed on-demand in the case
/// that the root certificates need to be updated.
#[cfg(feature = "mozilla-certs")]
pub fn mozilla() -> Self {
Self {
roots: webpki_root_certs::TLS_SERVER_ROOT_CERTS
.iter()
.map(|cert| CertificateDer(cert.to_vec()))
.collect(),
}
}
}
/// Server certificate verifier.
@@ -97,12 +82,8 @@ impl ServerCertVerifier {
Ok(Self { roots })
}
/// Creates a server certificate verifier with Mozilla root certificates.
///
/// These certificates are sourced from [`webpki-root-certs`](https://docs.rs/webpki-root-certs/latest/webpki_root_certs/). It is not recommended to use these unless the
/// application binary can be recompiled and deployed on-demand in the case
/// that the root certificates need to be updated.
#[cfg(feature = "mozilla-certs")]
/// Creates a new server certificate verifier with Mozilla root
/// certificates.
pub fn mozilla() -> Self {
Self {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),

View File

@@ -15,7 +15,6 @@ tlsn-server-fixture = { workspace = true }
tlsn-server-fixture-certs = { workspace = true }
spansy = { workspace = true }
anyhow = { workspace = true }
bincode = { workspace = true }
chrono = { workspace = true }
clap = { version = "4.5", features = ["derive"] }
@@ -38,9 +37,7 @@ tokio = { workspace = true, features = [
tokio-util = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
noir = { git = "https://github.com/zkmopro/noir-rs", tag = "v1.0.0-beta.8", features = [
"barretenberg",
] }
noir = { git = "https://github.com/zkmopro/noir-rs", tag = "v1.0.0-beta.8", features = ["barretenberg"] }
[[example]]
name = "interactive"

View File

@@ -4,7 +4,5 @@ This folder contains examples demonstrating how to use the TLSNotary protocol.
* [Interactive](./interactive/README.md): Interactive Prover and Verifier session without a trusted notary.
* [Attestation](./attestation/README.md): Performing a simple notarization with a trusted notary.
* [Interactive_zk](./interactive_zk/README.md): Interactive Prover and Verifier session demonstrating zero-knowledge age verification using Noir.
Refer to <https://tlsnotary.org/docs/quick_start> for a quick start guide to using TLSNotary with these examples.

View File

@@ -4,7 +4,6 @@
use std::env;
use anyhow::{anyhow, Result};
use clap::Parser;
use http_body_util::Empty;
use hyper::{body::Bytes, Request, StatusCode};
@@ -24,17 +23,12 @@ use tlsn::{
Attestation, AttestationConfig, CryptoProvider, Secrets,
},
config::{
prove::ProveConfig,
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig},
verifier::VerifierConfig,
CertificateDer, PrivateKeyDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore,
},
connection::{ConnectionInfo, HandshakeData, ServerName, TranscriptLength},
prover::{state::Committed, Prover, ProverOutput},
prover::{state::Committed, ProveConfig, Prover, ProverConfig, ProverOutput, TlsConfig},
transcript::{ContentType, TranscriptCommitConfig},
verifier::{Verifier, VerifierOutput},
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
};
use tlsn_examples::ExampleType;
use tlsn_formats::http::{DefaultHttpCommitter, HttpCommit, HttpTranscript};
@@ -53,7 +47,7 @@ struct Args {
}
#[tokio::main]
async fn main() -> Result<()> {
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
let args = Args::parse();
@@ -93,63 +87,64 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
uri: &str,
extra_headers: Vec<(&str, &str)>,
example_type: &ExampleType,
) -> Result<()> {
) -> Result<(), Box<dyn std::error::Error>> {
let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into());
let server_port: u16 = env::var("SERVER_PORT")
.map(|port| port.parse().expect("port should be valid integer"))
.unwrap_or(DEFAULT_FIXTURE_PORT);
// Create a new prover and perform necessary setup.
let prover = Prover::new(ProverConfig::builder().build()?)
.commit(
TlsCommitConfig::builder()
// Select the TLS commitment protocol.
.protocol(
MpcTlsConfig::builder()
// We must configure the amount of data we expect to exchange beforehand,
// which will be preprocessed prior to the
// connection. Reducing these limits will improve
// performance.
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
.build()?,
)
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
// (Optional) Set up TLS client authentication if required by the server.
.client_auth((
vec![CertificateDer(CLIENT_CERT_DER.to_vec())],
PrivateKeyDer(CLIENT_KEY_DER.to_vec()),
));
let tls_config = tls_config_builder.build().unwrap();
// Set up protocol configuration for prover.
let mut prover_config_builder = ProverConfig::builder();
prover_config_builder
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
// We must configure the amount of data we expect to exchange beforehand, which will
// be preprocessed prior to the connection. Reducing these limits will improve
// performance.
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
.build()?,
socket.compat(),
)
.await?;
);
let prover_config = prover_config_builder.build()?;
// Create a new prover and perform necessary setup.
let prover = Prover::new(prover_config).setup(socket.compat()).await?;
// Open a TCP connection to the server.
let client_socket = tokio::net::TcpStream::connect((server_host, server_port)).await?;
// Bind the prover to the server connection.
let (tls_connection, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
// (Optional) Set up TLS client authentication if required by the server.
.client_auth((
vec![CertificateDer(CLIENT_CERT_DER.to_vec())],
PrivateKeyDer(CLIENT_KEY_DER.to_vec()),
))
.build()?,
client_socket.compat(),
)
.await?;
let tls_connection = TokioIo::new(tls_connection.compat());
// The returned `mpc_tls_connection` is an MPC TLS connection to the server: all
// data written to/read from it will be encrypted/decrypted using MPC with
// the notary.
let (mpc_tls_connection, prover_fut) = prover.connect(client_socket.compat()).await?;
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat());
// Spawn the prover task to be run concurrently in the background.
let prover_task = tokio::spawn(prover_fut);
// Attach the hyper HTTP client to the connection.
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(tls_connection).await?;
hyper::client::conn::http1::handshake(mpc_tls_connection).await?;
// Spawn the HTTP task to be run concurrently in the background.
tokio::spawn(connection);
@@ -170,7 +165,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
}
let request = request_builder.body(Empty::<Bytes>::new())?;
info!("Starting connection with the server");
info!("Starting an MPC TLS connection with the server");
// Send the request to the server and wait for the response.
let response = request_sender.send_request(request).await?;
@@ -180,7 +175,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
assert!(response.status() == StatusCode::OK);
// The prover task should be done now, so we can await it.
let prover = prover_task.await??;
let mut prover = prover_task.await??;
// Parse the HTTP transcript.
let transcript = HttpTranscript::parse(prover.transcript())?;
@@ -222,7 +217,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
let request_config = builder.build()?;
let (attestation, secrets) = notarize(prover, &request_config, req_tx, resp_rx).await?;
let (attestation, secrets) = notarize(&mut prover, &request_config, req_tx, resp_rx).await?;
// Write the attestation to disk.
let attestation_path = tlsn_examples::get_file_path(example_type, "attestation");
@@ -243,11 +238,11 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
}
async fn notarize(
mut prover: Prover<Committed>,
prover: &mut Prover<Committed>,
config: &RequestConfig,
request_tx: Sender<AttestationRequest>,
attestation_rx: Receiver<Attestation>,
) -> Result<(Attestation, Secrets)> {
) -> Result<(Attestation, Secrets), Box<dyn std::error::Error>> {
let mut builder = ProveConfig::builder(prover.transcript());
if let Some(config) = config.transcript_commit() {
@@ -262,27 +257,25 @@ async fn notarize(
..
} = prover.prove(&disclosure_config).await?;
let transcript = prover.transcript().clone();
let tls_transcript = prover.tls_transcript().clone();
prover.close().await?;
// Build an attestation request.
let mut builder = AttestationRequest::builder(config);
builder
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.handshake_data(HandshakeData {
certs: tls_transcript
certs: prover
.tls_transcript()
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: tls_transcript
sig: prover
.tls_transcript()
.server_signature()
.expect("server signature is present")
.clone(),
binding: tls_transcript.certificate_binding().clone(),
binding: prover.tls_transcript().certificate_binding().clone(),
})
.transcript(transcript)
.transcript(prover.transcript().clone())
.transcript_commitments(transcript_secrets, transcript_commitments);
let (request, secrets) = builder.build(&CryptoProvider::default())?;
@@ -290,18 +283,15 @@ async fn notarize(
// Send attestation request to notary.
request_tx
.send(request.clone())
.map_err(|_| anyhow!("notary is not receiving attestation request"))?;
.map_err(|_| "notary is not receiving attestation request".to_string())?;
// Receive attestation from notary.
let attestation = attestation_rx
.await
.map_err(|err| anyhow!("notary did not respond with attestation: {err}"))?;
// Signature verifier for the signature algorithm in the request.
let provider = CryptoProvider::default();
.map_err(|err| format!("notary did not respond with attestation: {err}"))?;
// Check the attestation is consistent with the Prover's view.
request.validate(&attestation, &provider)?;
request.validate(&attestation)?;
Ok((attestation, secrets))
}
@@ -310,7 +300,14 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: S,
request_rx: Receiver<AttestationRequest>,
attestation_tx: Sender<Attestation>,
) -> Result<()> {
) -> Result<(), Box<dyn std::error::Error>> {
// Set up Verifier.
let config_validator = ProtocolConfigValidator::builder()
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
.build()
.unwrap();
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
@@ -318,24 +315,20 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(config_validator)
.build()
.unwrap();
let verifier = Verifier::new(verifier_config)
.commit(socket.compat())
.await?
.accept()
let mut verifier = Verifier::new(verifier_config)
.setup(socket.compat())
.await?
.run()
.await?;
let (
VerifierOutput {
transcript_commitments,
..
},
verifier,
) = verifier.verify().await?.accept().await?;
let VerifierOutput {
transcript_commitments,
..
} = verifier.verify(&VerifyConfig::default()).await?;
let tls_transcript = verifier.tls_transcript().clone();
@@ -397,7 +390,7 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
// Send attestation to prover.
attestation_tx
.send(attestation)
.map_err(|_| anyhow!("prover is not receiving attestation"))?;
.map_err(|_| "prover is not receiving attestation".to_string())?;
Ok(())
}

View File

@@ -12,8 +12,8 @@ use tlsn::{
signing::VerifyingKey,
CryptoProvider,
},
config::{CertificateDer, RootCertStore},
verifier::ServerCertVerifier,
webpki::{CertificateDer, RootCertStore},
};
use tlsn_examples::ExampleType;
use tlsn_server_fixture_certs::CA_CERT_DER;

View File

@@ -3,7 +3,6 @@ use std::{
net::{IpAddr, SocketAddr},
};
use anyhow::Result;
use http_body_util::Empty;
use hyper::{body::Bytes, Request, StatusCode, Uri};
use hyper_util::rt::TokioIo;
@@ -12,18 +11,11 @@ use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tracing::instrument;
use tlsn::{
config::{
prove::ProveConfig,
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig, TlsCommitProtocolConfig},
verifier::VerifierConfig,
},
config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore},
connection::ServerName,
prover::Prover,
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
transcript::PartialTranscript,
verifier::{Verifier, VerifierOutput},
webpki::{CertificateDer, RootCertStore},
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
};
use tlsn_server_fixture::DEFAULT_FIXTURE_PORT;
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
@@ -54,7 +46,7 @@ async fn main() {
let (prover_socket, verifier_socket) = tokio::io::duplex(1 << 23);
let prover = prover(prover_socket, &server_addr, &uri);
let verifier = verifier(verifier_socket);
let (_, transcript) = tokio::try_join!(prover, verifier).unwrap();
let (_, transcript) = tokio::join!(prover, verifier);
println!("Successfully verified {}", &uri);
println!(
@@ -72,57 +64,61 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
verifier_socket: T,
server_addr: &SocketAddr,
uri: &str,
) -> Result<()> {
) {
let uri = uri.parse::<Uri>().unwrap();
assert_eq!(uri.scheme().unwrap().as_str(), "https");
let server_domain = uri.authority().unwrap().host();
// Create a new prover and perform necessary setup.
let prover = Prover::new(ProverConfig::builder().build()?)
.commit(
TlsCommitConfig::builder()
// Select the TLS commitment protocol.
.protocol(
MpcTlsConfig::builder()
// We must configure the amount of data we expect to exchange beforehand,
// which will be preprocessed prior to the
// connection. Reducing these limits will improve
// performance.
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
.build()?,
)
.build()?,
verifier_socket.compat(),
)
.await?;
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build().unwrap();
// Open a TCP connection to the server.
let client_socket = tokio::net::TcpStream::connect(server_addr).await?;
// Set up protocol configuration for prover.
let mut prover_config_builder = ProverConfig::builder();
prover_config_builder
.server_name(ServerName::Dns(server_domain.try_into().unwrap()))
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()
.unwrap(),
);
// Bind the prover to the server connection.
let (tls_connection, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
client_socket.compat(),
)
.await?;
let tls_connection = TokioIo::new(tls_connection.compat());
let prover_config = prover_config_builder.build().unwrap();
// Create prover and connect to verifier.
//
// Perform the setup phase with the verifier.
let prover = Prover::new(prover_config)
.setup(verifier_socket.compat())
.await
.unwrap();
// Connect to TLS Server.
let tls_client_socket = tokio::net::TcpStream::connect(server_addr).await.unwrap();
// Pass server connection into the prover.
let (mpc_tls_connection, prover_fut) =
prover.connect(tls_client_socket.compat()).await.unwrap();
// Wrap the connection in a TokioIo compatibility layer to use it with hyper.
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat());
// Spawn the Prover to run in the background.
let prover_task = tokio::spawn(prover_fut);
// MPC-TLS Handshake.
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(tls_connection).await?;
hyper::client::conn::http1::handshake(mpc_tls_connection)
.await
.unwrap();
// Spawn the connection to run in the background.
tokio::spawn(connection);
@@ -134,13 +130,14 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
.header("Connection", "close")
.header("Secret", SECRET)
.method("GET")
.body(Empty::<Bytes>::new())?;
let response = request_sender.send_request(request).await?;
.body(Empty::<Bytes>::new())
.unwrap();
let response = request_sender.send_request(request).await.unwrap();
assert!(response.status() == StatusCode::OK);
// Create proof for the Verifier.
let mut prover = prover_task.await??;
let mut prover = prover_task.await.unwrap().unwrap();
let mut builder = ProveConfig::builder(prover.transcript());
@@ -156,8 +153,10 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
.expect("the secret should be in the sent data");
// Reveal everything except for the secret.
builder.reveal_sent(&(0..pos))?;
builder.reveal_sent(&(pos + SECRET.len()..prover.transcript().sent().len()))?;
builder.reveal_sent(&(0..pos)).unwrap();
builder
.reveal_sent(&(pos + SECRET.len()..prover.transcript().sent().len()))
.unwrap();
// Find the substring "Dick".
let pos = prover
@@ -168,21 +167,28 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
.expect("the substring 'Dick' should be in the received data");
// Reveal everything except for the substring.
builder.reveal_recv(&(0..pos))?;
builder.reveal_recv(&(pos + 4..prover.transcript().received().len()))?;
builder.reveal_recv(&(0..pos)).unwrap();
builder
.reveal_recv(&(pos + 4..prover.transcript().received().len()))
.unwrap();
let config = builder.build()?;
let config = builder.build().unwrap();
prover.prove(&config).await?;
prover.close().await?;
Ok(())
prover.prove(&config).await.unwrap();
prover.close().await.unwrap();
}
#[instrument(skip(socket))]
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: T,
) -> Result<PartialTranscript> {
) -> PartialTranscript {
// Set up Verifier.
let config_validator = ProtocolConfigValidator::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()
.unwrap();
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
@@ -190,56 +196,20 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?;
.protocol_config_validator(config_validator)
.build()
.unwrap();
let verifier = Verifier::new(verifier_config);
// Validate the proposed configuration and then run the TLS commitment protocol.
let verifier = verifier.commit(socket.compat()).await?;
// This is the opportunity to ensure the prover does not attempt to overload the
// verifier.
let reject = if let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = verifier.request().protocol()
{
if mpc_tls_config.max_sent_data() > MAX_SENT_DATA {
Some("max_sent_data is too large")
} else if mpc_tls_config.max_recv_data() > MAX_RECV_DATA {
Some("max_recv_data is too large")
} else {
None
}
} else {
Some("expecting to use MPC-TLS")
};
if reject.is_some() {
verifier.reject(reject).await?;
return Err(anyhow::anyhow!("protocol configuration rejected"));
}
// Runs the TLS commitment protocol to completion.
let verifier = verifier.accept().await?.run().await?;
// Validate the proving request and then verify.
let verifier = verifier.verify().await?;
if !verifier.request().server_identity() {
let verifier = verifier
.reject(Some("expecting to verify the server name"))
.await?;
verifier.close().await?;
return Err(anyhow::anyhow!("prover did not reveal the server name"));
}
let (
VerifierOutput {
server_name,
transcript,
..
},
verifier,
) = verifier.accept().await?;
verifier.close().await?;
// Receive authenticated data.
let VerifierOutput {
server_name,
transcript,
..
} = verifier
.verify(socket.compat(), &VerifyConfig::default())
.await
.unwrap();
let server_name = server_name.expect("prover should have revealed server name");
let transcript = transcript.expect("prover should have revealed transcript data");
@@ -262,7 +232,7 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
let ServerName::Dns(server_name) = server_name;
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
Ok(transcript)
transcript
}
/// Render redacted bytes as `🙈`.

View File

@@ -2,7 +2,6 @@ mod prover;
mod types;
mod verifier;
use anyhow::Result;
use prover::prover;
use std::{
env,
@@ -13,7 +12,7 @@ use tlsn_server_fixture_certs::SERVER_DOMAIN;
use verifier::verifier;
#[tokio::main]
async fn main() -> Result<()> {
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into());
@@ -26,7 +25,7 @@ async fn main() -> Result<()> {
let uri = format!("https://{SERVER_DOMAIN}:{server_port}/elster");
let server_ip: IpAddr = server_host
.parse()
.map_err(|e| anyhow::anyhow!("Invalid IP address '{server_host}': {e}"))?;
.map_err(|e| format!("Invalid IP address '{}': {}", server_host, e))?;
let server_addr = SocketAddr::from((server_ip, server_port));
// Connect prover and verifier.

View File

@@ -4,7 +4,6 @@ use crate::types::received_commitments;
use super::types::ZKProofBundle;
use anyhow::Result;
use chrono::{Datelike, Local, NaiveDate};
use http_body_util::Empty;
use hyper::{body::Bytes, header, Request, StatusCode, Uri};
@@ -22,27 +21,24 @@ use spansy::{
http::{BodyContent, Requests, Responses},
Spanned,
};
use tls_server_fixture::{CA_CERT_DER, SERVER_DOMAIN};
use tls_server_fixture::CA_CERT_DER;
use tlsn::{
config::{
prove::{ProveConfig, ProveConfigBuilder},
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig},
},
config::{CertificateDer, ProtocolConfig, RootCertStore},
connection::ServerName,
hash::HashAlgId,
prover::Prover,
prover::{ProveConfig, ProveConfigBuilder, Prover, ProverConfig, TlsConfig},
transcript::{
hash::{PlaintextHash, PlaintextHashSecret},
Direction, TranscriptCommitConfig, TranscriptCommitConfigBuilder, TranscriptCommitmentKind,
TranscriptSecret,
},
webpki::{CertificateDer, RootCertStore},
};
use tlsn_examples::{MAX_RECV_DATA, MAX_SENT_DATA};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tlsn_examples::MAX_RECV_DATA;
use tokio::io::AsyncWriteExt;
use tlsn_examples::MAX_SENT_DATA;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tracing::instrument;
@@ -52,64 +48,60 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
mut verifier_extra_socket: T,
server_addr: &SocketAddr,
uri: &str,
) -> Result<()> {
) -> Result<(), Box<dyn std::error::Error>> {
let uri = uri.parse::<Uri>()?;
if uri.scheme().map(|s| s.as_str()) != Some("https") {
return Err(anyhow::anyhow!("URI must use HTTPS scheme"));
return Err("URI must use HTTPS scheme".into());
}
let server_domain = uri
.authority()
.ok_or_else(|| anyhow::anyhow!("URI must have authority"))?
.host();
let server_domain = uri.authority().ok_or("URI must have authority")?.host();
// Create a new prover and perform necessary setup.
let prover = Prover::new(ProverConfig::builder().build()?)
.commit(
TlsCommitConfig::builder()
// Select the TLS commitment protocol.
.protocol(
MpcTlsConfig::builder()
// We must configure the amount of data we expect to exchange beforehand,
// which will be preprocessed prior to the
// connection. Reducing these limits will improve
// performance.
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()?,
)
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build()?;
// Set up protocol configuration for prover.
let mut prover_config_builder = ProverConfig::builder();
prover_config_builder
.server_name(ServerName::Dns(server_domain.try_into()?))
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()?,
verifier_socket.compat(),
)
);
let prover_config = prover_config_builder.build()?;
// Create prover and connect to verifier.
//
// Perform the setup phase with the verifier.
let prover = Prover::new(prover_config)
.setup(verifier_socket.compat())
.await?;
// Open a TCP connection to the server.
let client_socket = tokio::net::TcpStream::connect(server_addr).await?;
// Connect to TLS Server.
let tls_client_socket = tokio::net::TcpStream::connect(server_addr).await?;
// Bind the prover to the server connection.
let (tls_connection, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
client_socket.compat(),
)
.await?;
let tls_connection = TokioIo::new(tls_connection.compat());
// Pass server connection into the prover.
let (mpc_tls_connection, prover_fut) = prover.connect(tls_client_socket.compat()).await?;
// Wrap the connection in a TokioIo compatibility layer to use it with hyper.
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat());
// Spawn the Prover to run in the background.
let prover_task = tokio::spawn(prover_fut);
// MPC-TLS Handshake.
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(tls_connection).await?;
hyper::client::conn::http1::handshake(mpc_tls_connection).await?;
// Spawn the connection to run in the background.
tokio::spawn(connection);
@@ -126,10 +118,7 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let response = request_sender.send_request(request).await?;
if response.status() != StatusCode::OK {
return Err(anyhow::anyhow!(
"MPC-TLS request failed with status {}",
response.status()
));
return Err(format!("MPC-TLS request failed with status {}", response.status()).into());
}
// Create proof for the Verifier.
@@ -174,11 +163,11 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let received_commitments = received_commitments(&prover_output.transcript_commitments);
let received_commitment = received_commitments
.first()
.ok_or_else(|| anyhow::anyhow!("No received commitments found"))?; // committed hash (of date of birth string)
.ok_or("No received commitments found")?; // committed hash (of date of birth string)
let received_secrets = received_secrets(&prover_output.transcript_secrets);
let received_secret = received_secrets
.first()
.ok_or_else(|| anyhow::anyhow!("No received secrets found"))?; // hash blinder
.ok_or("No received secrets found")?; // hash blinder
let proof_input = prepare_zk_proof_input(received, received_commitment, received_secret)?;
let proof_bundle = generate_zk_proof(&proof_input)?;
@@ -191,30 +180,28 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
}
// Reveal everything from the request, except for the authorization token.
fn reveal_request(request: &[u8], builder: &mut ProveConfigBuilder<'_>) -> Result<()> {
fn reveal_request(
request: &[u8],
builder: &mut ProveConfigBuilder<'_>,
) -> Result<(), Box<dyn std::error::Error>> {
let reqs = Requests::new_from_slice(request).collect::<Result<Vec<_>, _>>()?;
let req = reqs
.first()
.ok_or_else(|| anyhow::anyhow!("No requests found"))?;
let req = reqs.first().ok_or("No requests found")?;
if req.request.method.as_str() != "GET" {
return Err(anyhow::anyhow!(
"Expected GET method, found {}",
req.request.method.as_str()
));
return Err(format!("Expected GET method, found {}", req.request.method.as_str()).into());
}
let authorization_header = req
.headers_with_name(header::AUTHORIZATION.as_str())
.next()
.ok_or_else(|| anyhow::anyhow!("Authorization header not found"))?;
.ok_or("Authorization header not found")?;
let start_pos = authorization_header
.span()
.indices()
.min()
.ok_or_else(|| anyhow::anyhow!("Could not find authorization header start position"))?
.ok_or("Could not find authorization header start position")?
+ header::AUTHORIZATION.as_str().len()
+ 2;
let end_pos =
@@ -230,43 +217,38 @@ fn reveal_received(
received: &[u8],
builder: &mut ProveConfigBuilder<'_>,
transcript_commitment_builder: &mut TranscriptCommitConfigBuilder,
) -> Result<()> {
) -> Result<(), Box<dyn std::error::Error>> {
let resp = Responses::new_from_slice(received).collect::<Result<Vec<_>, _>>()?;
let response = resp
.first()
.ok_or_else(|| anyhow::anyhow!("No responses found"))?;
let body = response
.body
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Response body not found"))?;
let response = resp.first().ok_or("No responses found")?;
let body = response.body.as_ref().ok_or("Response body not found")?;
let BodyContent::Json(json) = &body.content else {
return Err(anyhow::anyhow!("Expected JSON body content"));
return Err("Expected JSON body content".into());
};
// reveal tax year
let tax_year = json
.get("tax_year")
.ok_or_else(|| anyhow::anyhow!("tax_year field not found in JSON"))?;
.ok_or("tax_year field not found in JSON")?;
let start_pos = tax_year
.span()
.indices()
.min()
.ok_or_else(|| anyhow::anyhow!("Could not find tax_year start position"))?
.ok_or("Could not find tax_year start position")?
- 11;
let end_pos = tax_year
.span()
.indices()
.max()
.ok_or_else(|| anyhow::anyhow!("Could not find tax_year end position"))?
.ok_or("Could not find tax_year end position")?
+ 1;
builder.reveal_recv(&(start_pos..end_pos))?;
// commit to hash of date of birth
let dob = json
.get("taxpayer.date_of_birth")
.ok_or_else(|| anyhow::anyhow!("taxpayer.date_of_birth field not found in JSON"))?;
.ok_or("taxpayer.date_of_birth field not found in JSON")?;
transcript_commitment_builder.commit_recv(dob.span())?;
@@ -297,7 +279,7 @@ fn prepare_zk_proof_input(
received: &[u8],
received_commitment: &PlaintextHash,
received_secret: &PlaintextHashSecret,
) -> Result<ZKProofInput> {
) -> Result<ZKProofInput, Box<dyn std::error::Error>> {
assert_eq!(received_commitment.direction, Direction::Received);
assert_eq!(received_commitment.hash.alg, HashAlgId::SHA256);
@@ -306,11 +288,11 @@ fn prepare_zk_proof_input(
let dob_start = received_commitment
.idx
.min()
.ok_or_else(|| anyhow::anyhow!("No start index for DOB"))?;
.ok_or("No start index for DOB")?;
let dob_end = received_commitment
.idx
.end()
.ok_or_else(|| anyhow::anyhow!("No end index for DOB"))?;
.ok_or("No end index for DOB")?;
let dob = received[dob_start..dob_end].to_vec();
let blinder = received_secret.blinder.as_bytes().to_vec();
let committed_hash = hash.value.as_bytes().to_vec();
@@ -324,10 +306,8 @@ fn prepare_zk_proof_input(
hasher.update(&blinder);
let computed_hash = hasher.finalize();
if committed_hash != computed_hash.as_ref() as &[u8] {
return Err(anyhow::anyhow!(
"Computed hash does not match committed hash"
));
if committed_hash != computed_hash.as_slice() {
return Err("Computed hash does not match committed hash".into());
}
Ok(ZKProofInput {
@@ -338,7 +318,9 @@ fn prepare_zk_proof_input(
})
}
fn generate_zk_proof(proof_input: &ZKProofInput) -> Result<ZKProofBundle> {
fn generate_zk_proof(
proof_input: &ZKProofInput,
) -> Result<ZKProofBundle, Box<dyn std::error::Error>> {
tracing::info!("🔒 Generating ZK proof with Noir...");
const PROGRAM_JSON: &str = include_str!("./noir/target/noir.json");
@@ -347,7 +329,7 @@ fn generate_zk_proof(proof_input: &ZKProofInput) -> Result<ZKProofBundle> {
let json: Value = serde_json::from_str(PROGRAM_JSON)?;
let bytecode = json["bytecode"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("bytecode field not found in program.json"))?;
.ok_or("bytecode field not found in program.json")?;
let mut inputs: Vec<String> = vec![];
inputs.push(proof_input.proof_date.day().to_string());
@@ -372,17 +354,16 @@ fn generate_zk_proof(proof_input: &ZKProofInput) -> Result<ZKProofBundle> {
tracing::debug!("Witness inputs {:?}", inputs);
let input_refs: Vec<&str> = inputs.iter().map(String::as_str).collect();
let witness = from_vec_str_to_witness_map(input_refs).map_err(|e| anyhow::anyhow!(e))?;
let witness = from_vec_str_to_witness_map(input_refs)?;
// Setup SRS
setup_srs_from_bytecode(bytecode, None, false).map_err(|e| anyhow::anyhow!(e))?;
setup_srs_from_bytecode(bytecode, None, false)?;
// Verification key
let vk = get_ultra_honk_verification_key(bytecode, false).map_err(|e| anyhow::anyhow!(e))?;
let vk = get_ultra_honk_verification_key(bytecode, false)?;
// Generate proof
let proof = prove_ultra_honk(bytecode, witness.clone(), vk.clone(), false)
.map_err(|e| anyhow::anyhow!(e))?;
let proof = prove_ultra_honk(bytecode, witness.clone(), vk.clone(), false)?;
tracing::info!("✅ Proof generated ({} bytes)", proof.len());
let proof_bundle = ZKProofBundle { vk, proof };

View File

@@ -1,18 +1,16 @@
use crate::types::received_commitments;
use super::types::ZKProofBundle;
use anyhow::Result;
use chrono::{Local, NaiveDate};
use noir::barretenberg::verify::{get_ultra_honk_verification_key, verify_ultra_honk};
use serde_json::Value;
use tls_server_fixture::CA_CERT_DER;
use tlsn::{
config::{tls_commit::TlsCommitProtocolConfig, verifier::VerifierConfig},
config::{CertificateDer, ProtocolConfigValidator, RootCertStore},
connection::ServerName,
hash::HashAlgId,
transcript::{Direction, PartialTranscript},
verifier::{Verifier, VerifierOutput},
webpki::{CertificateDer, RootCertStore},
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
};
use tlsn_examples::{MAX_RECV_DATA, MAX_SENT_DATA};
use tlsn_server_fixture_certs::SERVER_DOMAIN;
@@ -24,91 +22,56 @@ use tracing::instrument;
pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: T,
mut extra_socket: T,
) -> Result<PartialTranscript> {
let verifier = Verifier::new(
VerifierConfig::builder()
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
);
) -> Result<PartialTranscript, Box<dyn std::error::Error>> {
// Set up Verifier.
let config_validator = ProtocolConfigValidator::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()?;
// Validate the proposed configuration and then run the TLS commitment protocol.
let verifier = verifier.commit(socket.compat()).await?;
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let verifier_config = VerifierConfig::builder()
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(config_validator)
.build()?;
// This is the opportunity to ensure the prover does not attempt to overload the
// verifier.
let reject = if let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = verifier.request().protocol()
{
if mpc_tls_config.max_sent_data() > MAX_SENT_DATA {
Some("max_sent_data is too large")
} else if mpc_tls_config.max_recv_data() > MAX_RECV_DATA {
Some("max_recv_data is too large")
} else {
None
}
} else {
Some("expecting to use MPC-TLS")
};
let verifier = Verifier::new(verifier_config);
if reject.is_some() {
verifier.reject(reject).await?;
return Err(anyhow::anyhow!("protocol configuration rejected"));
}
// Receive authenticated data.
let VerifierOutput {
server_name,
transcript,
transcript_commitments,
..
} = verifier
.verify(socket.compat(), &VerifyConfig::default())
.await?;
// Runs the TLS commitment protocol to completion.
let verifier = verifier.accept().await?.run().await?;
// Validate the proving request and then verify.
let verifier = verifier.verify().await?;
let request = verifier.request();
if !request.server_identity() || request.reveal().is_none() {
let verifier = verifier
.reject(Some(
"expecting to verify the server name and transcript data",
))
.await?;
verifier.close().await?;
return Err(anyhow::anyhow!(
"prover did not reveal the server name and transcript data"
));
}
let (
VerifierOutput {
server_name,
transcript,
transcript_commitments,
..
},
verifier,
) = verifier.accept().await?;
verifier.close().await?;
let server_name = server_name.expect("server name should be present");
let transcript = transcript.expect("transcript should be present");
let server_name = server_name.ok_or("Prover should have revealed server name")?;
let transcript = transcript.ok_or("Prover should have revealed transcript data")?;
// Create hash commitment for the date of birth field from the response
let sent = transcript.sent_unsafe().to_vec();
let sent_data = String::from_utf8(sent.clone())
.map_err(|e| anyhow::anyhow!("Verifier expected valid UTF-8 sent data: {e}"))?;
.map_err(|e| format!("Verifier expected valid UTF-8 sent data: {}", e))?;
if !sent_data.contains(SERVER_DOMAIN) {
return Err(anyhow::anyhow!(
"Verification failed: Expected host {SERVER_DOMAIN} not found in sent data"
));
return Err(format!(
"Verification failed: Expected host {} not found in sent data",
SERVER_DOMAIN
)
.into());
}
// Check received data.
let received_commitments = received_commitments(&transcript_commitments);
let received_commitment = received_commitments
.first()
.ok_or_else(|| anyhow::anyhow!("Missing hash commitment"))?;
.ok_or("Missing received hash commitment")?;
assert!(received_commitment.direction == Direction::Received);
assert!(received_commitment.hash.alg == HashAlgId::SHA256);
@@ -118,10 +81,12 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
// Check Session info: server name.
let ServerName::Dns(server_name) = server_name;
if server_name.as_str() != SERVER_DOMAIN {
return Err(anyhow::anyhow!(
"Server name mismatch: expected {SERVER_DOMAIN}, got {}",
return Err(format!(
"Server name mismatch: expected {}, got {}",
SERVER_DOMAIN,
server_name.as_str()
));
)
.into());
}
// Receive ZKProof information from prover
@@ -129,28 +94,26 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
extra_socket.read_to_end(&mut buf).await?;
if buf.is_empty() {
return Err(anyhow::anyhow!("No ZK proof data received from prover"));
return Err("No ZK proof data received from prover".into());
}
let msg: ZKProofBundle = bincode::deserialize(&buf)
.map_err(|e| anyhow::anyhow!("Failed to deserialize ZK proof bundle: {e}"))?;
.map_err(|e| format!("Failed to deserialize ZK proof bundle: {}", e))?;
// Verify zk proof
const PROGRAM_JSON: &str = include_str!("./noir/target/noir.json");
let json: Value = serde_json::from_str(PROGRAM_JSON)
.map_err(|e| anyhow::anyhow!("Failed to parse Noir circuit: {e}"))?;
.map_err(|e| format!("Failed to parse Noir circuit: {}", e))?;
let bytecode = json["bytecode"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("Bytecode field missing in noir.json"))?;
.ok_or("Bytecode field missing in noir.json")?;
let vk = get_ultra_honk_verification_key(bytecode, false)
.map_err(|e| anyhow::anyhow!("Failed to get verification key: {e}"))?;
.map_err(|e| format!("Failed to get verification key: {}", e))?;
if vk != msg.vk {
return Err(anyhow::anyhow!(
"Verification key mismatch between computed and provided by prover"
));
return Err("Verification key mismatch between computed and provided by prover".into());
}
let proof = msg.proof.clone();
@@ -162,10 +125,12 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
// * and 32*32 bytes for the hash
let min_bytes = (32 + 3) * 32;
if proof.len() < min_bytes {
return Err(anyhow::anyhow!(
"Proof too short: expected at least {min_bytes} bytes, got {}",
return Err(format!(
"Proof too short: expected at least {} bytes, got {}",
min_bytes,
proof.len()
));
)
.into());
}
// Check that the proof date is correctly included in the proof
@@ -174,12 +139,14 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
let proof_date_year: i32 = i32::from_be_bytes(proof[92..96].try_into()?);
let proof_date_from_proof =
NaiveDate::from_ymd_opt(proof_date_year, proof_date_month, proof_date_day)
.ok_or_else(|| anyhow::anyhow!("Invalid proof date in proof"))?;
.ok_or("Invalid proof date in proof")?;
let today = Local::now().date_naive();
if (today - proof_date_from_proof).num_days() < 0 {
return Err(anyhow::anyhow!(
"The proof date can only be today or in the past: provided {proof_date_from_proof}, today {today}"
));
return Err(format!(
"The proof date can only be today or in the past: provided {}, today {}",
proof_date_from_proof, today
)
.into());
}
// Check that the committed hash in the proof matches the hash from the
@@ -197,9 +164,7 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
hex::encode(&committed_hash_in_proof),
hex::encode(&expected_hash)
);
return Err(anyhow::anyhow!(
"Hash in proof does not match committed hash in MPC-TLS"
));
return Err("Hash in proof does not match committed hash in MPC-TLS".into());
}
tracing::info!(
"✅ The hash in the proof matches the committed hash in MPC-TLS ({})",
@@ -208,10 +173,10 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
// Finally verify the proof
let is_valid = verify_ultra_honk(msg.proof, msg.vk)
.map_err(|e| anyhow::anyhow!("ZKProof Verification failed: {e}"))?;
.map_err(|e| format!("ZKProof Verification failed: {}", e))?;
if !is_valid {
tracing::error!("❌ Age verification ZKProof failed to verify");
return Err(anyhow::anyhow!("Age verification ZKProof failed to verify"));
return Err("Age verification ZKProof failed to verify".into());
}
tracing::info!("✅ Age verification ZKProof successfully verified");

View File

@@ -1,6 +1,6 @@
[package]
name = "tlsn-formats"
version = "0.1.0-alpha.14-pre"
version = "0.1.0-alpha.13-pre"
edition = "2021"
[lints]

View File

@@ -1,59 +1,51 @@
#### Default Representative Benchmarks ####
#
# This benchmark measures TLSNotary performance on three representative network scenarios.
# Each scenario is run multiple times to produce statistical metrics (median, std dev, etc.)
# rather than plots. Use this for quick performance checks and CI regression testing.
#
# Payload sizes:
# - upload-size: 1KB (typical HTTP request)
# - download-size: 2KB (typical HTTP response/API data)
#
# Network scenarios are chosen to represent real-world user conditions where
# TLSNotary is primarily bottlenecked by upload bandwidth.
#### Cable/DSL Home Internet ####
# Most common residential internet connection
# - Asymmetric: high download, limited upload (typical bottleneck)
# - Upload bandwidth: 20 Mbps (realistic cable/DSL upload speed)
# - Latency: 20ms (typical ISP latency)
#### Latency ####
[[group]]
name = "cable"
bandwidth = 20
protocol_latency = 20
upload-size = 1024
download-size = 2048
name = "latency"
bandwidth = 1000
[[bench]]
group = "cable"
#### Mobile 5G ####
# Modern mobile connection with good coverage
# - Upload bandwidth: 30 Mbps (typical 5G upload in good conditions)
# - Latency: 30ms (higher than wired due to mobile tower hops)
[[group]]
name = "mobile_5g"
bandwidth = 30
protocol_latency = 30
upload-size = 1024
download-size = 2048
group = "latency"
protocol_latency = 10
[[bench]]
group = "mobile_5g"
group = "latency"
protocol_latency = 25
#### Fiber Home Internet ####
# High-end residential connection (best case scenario)
# - Symmetric: equal upload/download bandwidth
# - Upload bandwidth: 100 Mbps (typical fiber upload)
# - Latency: 15ms (lower latency than cable)
[[bench]]
group = "latency"
protocol_latency = 50
[[bench]]
group = "latency"
protocol_latency = 100
[[bench]]
group = "latency"
protocol_latency = 200
#### Bandwidth ####
[[group]]
name = "fiber"
name = "bandwidth"
protocol_latency = 25
[[bench]]
group = "bandwidth"
bandwidth = 10
[[bench]]
group = "bandwidth"
bandwidth = 50
[[bench]]
group = "bandwidth"
bandwidth = 100
protocol_latency = 15
upload-size = 1024
download-size = 2048
[[bench]]
group = "fiber"
group = "bandwidth"
bandwidth = 250
[[bench]]
group = "bandwidth"
bandwidth = 1000

View File

@@ -1,52 +0,0 @@
#### Bandwidth Sweep Benchmark ####
#
# Measures how network bandwidth affects TLSNotary runtime.
# Keeps latency and payload sizes fixed while varying upload bandwidth.
#
# Fixed parameters:
# - Latency: 25ms (typical internet latency)
# - Upload: 1KB (typical request)
# - Download: 2KB (typical response)
#
# Variable: Bandwidth from 5 Mbps to 1000 Mbps
#
# Use this to plot "Bandwidth vs Runtime" and understand bandwidth sensitivity.
# Focus on upload bandwidth as TLSNotary is primarily upload-bottlenecked
[[group]]
name = "bandwidth_sweep"
protocol_latency = 25
upload-size = 1024
download-size = 2048
[[bench]]
group = "bandwidth_sweep"
bandwidth = 5
[[bench]]
group = "bandwidth_sweep"
bandwidth = 10
[[bench]]
group = "bandwidth_sweep"
bandwidth = 20
[[bench]]
group = "bandwidth_sweep"
bandwidth = 50
[[bench]]
group = "bandwidth_sweep"
bandwidth = 100
[[bench]]
group = "bandwidth_sweep"
bandwidth = 250
[[bench]]
group = "bandwidth_sweep"
bandwidth = 500
[[bench]]
group = "bandwidth_sweep"
bandwidth = 1000

View File

@@ -1,53 +0,0 @@
#### Download Size Sweep Benchmark ####
#
# Measures how download payload size affects TLSNotary runtime.
# Keeps network conditions fixed while varying the response size.
#
# Fixed parameters:
# - Bandwidth: 100 Mbps (typical good connection)
# - Latency: 25ms (typical internet latency)
# - Upload: 1KB (typical request size)
#
# Variable: Download size from 1KB to 100KB
#
# Use this to plot "Download Size vs Runtime" and understand how much data
# TLSNotary can efficiently notarize. Useful for determining optimal
# chunking strategies for large responses.
[[group]]
name = "download_sweep"
bandwidth = 100
protocol_latency = 25
upload-size = 1024
[[bench]]
group = "download_sweep"
download-size = 1024
[[bench]]
group = "download_sweep"
download-size = 2048
[[bench]]
group = "download_sweep"
download-size = 5120
[[bench]]
group = "download_sweep"
download-size = 10240
[[bench]]
group = "download_sweep"
download-size = 20480
[[bench]]
group = "download_sweep"
download-size = 30720
[[bench]]
group = "download_sweep"
download-size = 40960
[[bench]]
group = "download_sweep"
download-size = 51200

View File

@@ -1,47 +0,0 @@
#### Latency Sweep Benchmark ####
#
# Measures how network latency affects TLSNotary runtime.
# Keeps bandwidth and payload sizes fixed while varying protocol latency.
#
# Fixed parameters:
# - Bandwidth: 100 Mbps (typical good connection)
# - Upload: 1KB (typical request)
# - Download: 2KB (typical response)
#
# Variable: Protocol latency from 10ms to 200ms
#
# Use this to plot "Latency vs Runtime" and understand latency sensitivity.
[[group]]
name = "latency_sweep"
bandwidth = 100
upload-size = 1024
download-size = 2048
[[bench]]
group = "latency_sweep"
protocol_latency = 10
[[bench]]
group = "latency_sweep"
protocol_latency = 25
[[bench]]
group = "latency_sweep"
protocol_latency = 50
[[bench]]
group = "latency_sweep"
protocol_latency = 75
[[bench]]
group = "latency_sweep"
protocol_latency = 100
[[bench]]
group = "latency_sweep"
protocol_latency = 150
[[bench]]
group = "latency_sweep"
protocol_latency = 200

View File

@@ -3,19 +3,7 @@
# Ensure the script runs in the folder that contains this script
cd "$(dirname "$0")"
RUNNER_FEATURES=""
EXECUTOR_FEATURES=""
if [ "$1" = "debug" ]; then
RUNNER_FEATURES="--features debug"
EXECUTOR_FEATURES="--no-default-features --features debug"
fi
cargo build --release \
--package tlsn-harness-runner $RUNNER_FEATURES \
--package tlsn-harness-executor $EXECUTOR_FEATURES \
--package tlsn-server-fixture \
--package tlsn-harness-plot
cargo build --release --package tlsn-harness-runner --package tlsn-harness-executor --package tlsn-server-fixture --package tlsn-harness-plot
mkdir -p bin

View File

@@ -9,7 +9,6 @@ pub const DEFAULT_UPLOAD_SIZE: usize = 1024;
pub const DEFAULT_DOWNLOAD_SIZE: usize = 4096;
pub const DEFAULT_DEFER_DECRYPTION: bool = true;
pub const DEFAULT_MEMORY_PROFILE: bool = false;
pub const DEFAULT_REVEAL_ALL: bool = false;
pub const WARM_UP_BENCH: Bench = Bench {
group: None,
@@ -21,7 +20,6 @@ pub const WARM_UP_BENCH: Bench = Bench {
download_size: 4096,
defer_decryption: true,
memory_profile: false,
reveal_all: true,
};
#[derive(Deserialize)]
@@ -81,8 +79,6 @@ pub struct BenchGroupItem {
pub defer_decryption: Option<bool>,
#[serde(rename = "memory-profile")]
pub memory_profile: Option<bool>,
#[serde(rename = "reveal-all")]
pub reveal_all: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -101,8 +97,6 @@ pub struct BenchItem {
pub defer_decryption: Option<bool>,
#[serde(rename = "memory-profile")]
pub memory_profile: Option<bool>,
#[serde(rename = "reveal-all")]
pub reveal_all: Option<bool>,
}
impl BenchItem {
@@ -138,10 +132,6 @@ impl BenchItem {
if self.memory_profile.is_none() {
self.memory_profile = group.memory_profile;
}
if self.reveal_all.is_none() {
self.reveal_all = group.reveal_all;
}
}
pub fn into_bench(&self) -> Bench {
@@ -155,7 +145,6 @@ impl BenchItem {
download_size: self.download_size.unwrap_or(DEFAULT_DOWNLOAD_SIZE),
defer_decryption: self.defer_decryption.unwrap_or(DEFAULT_DEFER_DECRYPTION),
memory_profile: self.memory_profile.unwrap_or(DEFAULT_MEMORY_PROFILE),
reveal_all: self.reveal_all.unwrap_or(DEFAULT_REVEAL_ALL),
}
}
}
@@ -175,8 +164,6 @@ pub struct Bench {
pub defer_decryption: bool,
#[serde(rename = "memory-profile")]
pub memory_profile: bool,
#[serde(rename = "reveal-all")]
pub reveal_all: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View File

@@ -22,10 +22,7 @@ pub enum CmdOutput {
GetTests(Vec<String>),
Test(TestOutput),
Bench(BenchOutput),
#[cfg(target_arch = "wasm32")]
Fail {
reason: Option<String>,
},
Fail { reason: Option<String> },
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View File

@@ -1,14 +1,10 @@
[target.wasm32-unknown-unknown]
rustflags = [
"-Ctarget-feature=+atomics,+bulk-memory,+mutable-globals,+simd128",
"-Clink-arg=--shared-memory",
"-C",
"target-feature=+atomics,+bulk-memory,+mutable-globals,+simd128",
"-C",
# 4GB
"-Clink-arg=--max-memory=4294967296",
"-Clink-arg=--import-memory",
"-Clink-arg=--export=__wasm_init_tls",
"-Clink-arg=--export=__tls_size",
"-Clink-arg=--export=__tls_align",
"-Clink-arg=--export=__tls_base",
"link-arg=--max-memory=4294967296",
"--cfg",
'getrandom_backend="wasm_js"',
]

View File

@@ -4,12 +4,6 @@ version = "0.1.0"
edition = "2024"
publish = false
[features]
# Disable tracing events as a workaround for issue 959.
default = ["tracing/release_max_level_off"]
# Used to debug the executor itself.
debug = []
[lib]
name = "harness_executor"
crate-type = ["cdylib", "rlib"]
@@ -34,7 +28,8 @@ tokio = { workspace = true, features = ["full"] }
tokio-util = { workspace = true, features = ["compat"] }
[target.'cfg(target_arch = "wasm32")'.dependencies]
tracing = { workspace = true }
# Disable tracing events as a workaround for issue 959.
tracing = { workspace = true, features = ["release_max_level_off"] }
wasm-bindgen = { workspace = true }
tlsn-wasm = { workspace = true }
js-sys = { workspace = true }

View File

@@ -5,15 +5,9 @@ use futures::{AsyncReadExt, AsyncWriteExt, TryFutureExt};
use harness_core::bench::{Bench, ProverMetrics};
use tlsn::{
config::{
prove::ProveConfig,
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{TlsCommitConfig, mpc::MpcTlsConfig},
},
config::{CertificateDer, ProtocolConfig, RootCertStore},
connection::ServerName,
prover::Prover,
webpki::{CertificateDer, RootCertStore},
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
};
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
@@ -28,47 +22,41 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let sent = verifier_io.sent();
let recv = verifier_io.recv();
let prover = Prover::new(ProverConfig::builder().build()?);
let mut builder = ProtocolConfig::builder();
builder.max_sent_data(config.upload_size);
builder.defer_decryption_from_start(config.defer_decryption);
if !config.defer_decryption {
builder.max_recv_data_online(config.download_size + RECV_PADDING);
}
builder.max_recv_data(config.download_size + RECV_PADDING);
let protocol_config = builder.build()?;
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build()?;
let prover = Prover::new(
ProverConfig::builder()
.tls_config(tls_config)
.protocol_config(protocol_config)
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.build()?,
);
let time_start = web_time::Instant::now();
let prover = prover
.commit(
TlsCommitConfig::builder()
.protocol({
let mut builder = MpcTlsConfig::builder()
.max_sent_data(config.upload_size)
.defer_decryption_from_start(config.defer_decryption);
if !config.defer_decryption {
builder = builder.max_recv_data_online(config.download_size + RECV_PADDING);
}
builder
.max_recv_data(config.download_size + RECV_PADDING)
.build()
}?)
.build()?,
verifier_io,
)
.await?;
let prover = prover.setup(verifier_io).await?;
let time_preprocess = time_start.elapsed().as_millis();
let time_start_online = web_time::Instant::now();
let uploaded_preprocess = sent.load(Ordering::Relaxed);
let downloaded_preprocess = recv.load(Ordering::Relaxed);
let (mut conn, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
provider.provide_server_io().await?,
)
.await?;
let (mut conn, prover_fut) = prover.connect(provider.provide_server_io().await?).await?;
let (_, mut prover) = futures::try_join!(
async {
@@ -98,27 +86,14 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let mut builder = ProveConfig::builder(prover.transcript());
// When reveal_all is false (the default), we exclude 1 byte to avoid the
// reveal-all optimization and benchmark the realistic ZK authentication path.
let reveal_sent_range = if config.reveal_all {
0..sent_len
} else {
0..sent_len.saturating_sub(1)
};
let reveal_recv_range = if config.reveal_all {
0..recv_len
} else {
0..recv_len.saturating_sub(1)
};
builder
.server_identity()
.reveal_sent(&reveal_sent_range)?
.reveal_recv(&reveal_recv_range)?;
.reveal_sent(&(0..sent_len))?
.reveal_recv(&(0..recv_len))?;
let prove_config = builder.build()?;
let config = builder.build()?;
prover.prove(&prove_config).await?;
prover.prove(&config).await?;
prover.close().await?;
let time_total = time_start.elapsed().as_millis();

View File

@@ -2,31 +2,33 @@ use anyhow::Result;
use harness_core::bench::Bench;
use tlsn::{
config::verifier::VerifierConfig,
verifier::Verifier,
webpki::{CertificateDer, RootCertStore},
config::{CertificateDer, ProtocolConfigValidator, RootCertStore},
verifier::{Verifier, VerifierConfig, VerifyConfig},
};
use tlsn_server_fixture_certs::CA_CERT_DER;
use crate::IoProvider;
use crate::{IoProvider, bench::RECV_PADDING};
pub async fn bench_verifier(provider: &IoProvider, config: &Bench) -> Result<()> {
let mut builder = ProtocolConfigValidator::builder();
builder
.max_sent_data(config.upload_size)
.max_recv_data(config.download_size + RECV_PADDING);
let protocol_config = builder.build()?;
pub async fn bench_verifier(provider: &IoProvider, _config: &Bench) -> Result<()> {
let verifier = Verifier::new(
VerifierConfig::builder()
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(protocol_config)
.build()?,
);
let verifier = verifier
.commit(provider.provide_proto_io().await?)
.await?
.accept()
.await?
.run()
.await?;
let (_, verifier) = verifier.verify().await?.accept().await?;
let verifier = verifier.setup(provider.provide_proto_io().await?).await?;
let mut verifier = verifier.run().await?;
verifier.verify(&VerifyConfig::default()).await?;
verifier.close().await?;
Ok(())

Some files were not shown because too many files have changed in this diff Show More