diff --git a/Cargo.lock b/Cargo.lock index 999f755a7..25f6f4414 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2026,7 +2026,7 @@ checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" [[package]] name = "clmul" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "bytemuck", "cfg-if", @@ -4225,7 +4225,7 @@ checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" [[package]] name = "matrix-transpose" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "thiserror 1.0.69", ] @@ -4282,7 +4282,7 @@ dependencies = [ [[package]] name = "mpz-circuits" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "mpz-circuits-core", "mpz-circuits-data", @@ -4291,7 +4291,7 @@ dependencies = [ [[package]] name = "mpz-circuits-core" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "bincode 1.3.3", "itybity 0.3.1", @@ -4306,7 +4306,7 @@ dependencies = [ [[package]] name = "mpz-circuits-data" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "bincode 1.3.3", "mpz-circuits-core", @@ -4316,7 +4316,7 @@ dependencies = [ [[package]] name = "mpz-cointoss" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "futures", "mpz-cointoss-core", @@ -4329,7 +4329,7 @@ dependencies = [ [[package]] name = "mpz-cointoss-core" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "mpz-core", "opaque-debug", @@ -4340,7 +4340,7 @@ dependencies = [ [[package]] name = "mpz-common" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "async-trait", "bytes", @@ -4360,7 +4360,7 @@ dependencies = [ [[package]] name = "mpz-core" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "aes 0.9.0-rc.2", "bcs", @@ -4386,7 +4386,7 @@ dependencies = [ [[package]] name = "mpz-fields" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "ark-ff 0.4.2", "ark-secp256r1", @@ -4406,7 +4406,7 @@ dependencies = [ [[package]] name = "mpz-garble" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "async-trait", "derive_builder 0.11.2", @@ -4432,7 +4432,7 @@ dependencies = [ [[package]] name = "mpz-garble-core" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "aes 0.9.0-rc.2", "bitvec", @@ -4463,7 +4463,7 @@ dependencies = [ [[package]] name = "mpz-hash" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "blake3", "itybity 0.3.1", @@ -4476,7 +4476,7 @@ dependencies = [ [[package]] name = "mpz-ideal-vm" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "async-trait", "futures", @@ -4493,7 +4493,7 @@ dependencies = [ [[package]] name = "mpz-memory-core" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "blake3", "futures", @@ -4508,7 +4508,7 @@ dependencies = [ [[package]] name = "mpz-ole" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "async-trait", "futures", @@ -4526,7 +4526,7 @@ dependencies = [ [[package]] name = "mpz-ole-core" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "hybrid-array", "itybity 0.3.1", @@ -4542,7 +4542,7 @@ dependencies = [ [[package]] name = "mpz-ot" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "async-trait", "cfg-if", @@ -4565,7 +4565,7 @@ dependencies = [ [[package]] name = "mpz-ot-core" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "aes 0.9.0-rc.2", "blake3", @@ -4596,7 +4596,7 @@ dependencies = [ [[package]] name = "mpz-predicate" version = "0.1.0-alpha.14-pre" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "bytes", "mpz-circuits", @@ -4612,7 +4612,7 @@ dependencies = [ [[package]] name = "mpz-share-conversion" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "async-trait", "mpz-common", @@ -4628,7 +4628,7 @@ dependencies = [ [[package]] name = "mpz-share-conversion-core" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "mpz-common", "mpz-core", @@ -4642,7 +4642,7 @@ dependencies = [ [[package]] name = "mpz-vm-core" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "async-trait", "futures", @@ -4655,7 +4655,7 @@ dependencies = [ [[package]] name = "mpz-zk" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "async-trait", "blake3", @@ -4673,7 +4673,7 @@ dependencies = [ [[package]] name = "mpz-zk-core" version = "0.1.0-alpha.4" -source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fpredicate_on_dev#59e65304af8619ddc69b77b7dd2a9b0ca7219eec" +source = "git+https://github.com/privacy-ethereum/mpz?branch=feat%2Fmpz-bool-type#9405caab5d65be1c7e7dd39c6026a9efb57a598d" dependencies = [ "blake3", "cfg-if", diff --git a/Cargo.toml b/Cargo.toml index b763210df..12a970f28 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,21 +66,21 @@ tlsn-harness-runner = { path = "crates/harness/runner" } tlsn-wasm = { path = "crates/wasm" } tlsn = { path = "crates/tlsn" } -mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/predicate_on_dev" } -mpz-common = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/predicate_on_dev" } -mpz-core = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/predicate_on_dev" } -mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/predicate_on_dev" } -mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/predicate_on_dev" } -mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/predicate_on_dev" } -mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/predicate_on_dev" } -mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/predicate_on_dev" } -mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/predicate_on_dev" } -mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/predicate_on_dev" } -mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/predicate_on_dev" } -mpz-predicate = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/predicate_on_dev" } -mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/predicate_on_dev" } -mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/predicate_on_dev" } -mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/predicate_on_dev" } +mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" } +mpz-common = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" } +mpz-core = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" } +mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" } +mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" } +mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" } +mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" } +mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" } +mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" } +mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" } +mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" } +mpz-predicate = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" } +mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" } +mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" } +mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" } rangeset = { version = "0.4" } serio = { version = "0.2" } diff --git a/crates/tlsn/src/transcript_internal/predicate.rs b/crates/tlsn/src/transcript_internal/predicate.rs index 62e90bcbe..6eb0a261f 100644 --- a/crates/tlsn/src/transcript_internal/predicate.rs +++ b/crates/tlsn/src/transcript_internal/predicate.rs @@ -3,9 +3,10 @@ use std::sync::Arc; use mpz_circuits::Circuit; +use mpz_core::bitvec::BitVec; use mpz_memory_core::{ - MemoryExt, Vector, ViewExt, - binary::{Binary, U8}, + DecodeFutureTyped, MemoryExt, + binary::{Binary, Bool}, }; use mpz_predicate::{Pred, compiler::Compiler}; use mpz_vm_core::{Call, CallableExt, Vm}; @@ -26,6 +27,15 @@ pub(crate) enum PredicateError { /// Circuit call error. #[error("circuit call error: {0}")] Call(#[from] mpz_vm_core::CallError), + /// Decode error. + #[error("decode error: {0}")] + Decode(#[from] mpz_memory_core::DecodeError), + /// Missing decoding. + #[error("missing decoding")] + MissingDecoding, + /// Predicate not satisfied. + #[error("predicate evaluated to false")] + PredicateNotSatisfied, } /// Converts a slice of indices to a RangeSet (each index becomes a single-byte @@ -58,17 +68,48 @@ pub(crate) fn prove_predicates>( // Get indices from the predicate and convert to RangeSet let indices = indices_to_rangeset(&predicate.indices()); - execute_predicate(vm, refs, &indices, &circuit)?; + // Prover doesn't need to verify output - they know their data satisfies the predicate + let _ = execute_predicate(vm, refs, &indices, &circuit)?; } Ok(()) } +/// Proof that predicates were satisfied. +/// +/// Must be verified after `vm.execute_all()` completes. +#[must_use] +pub(crate) struct PredicateProof { + /// Decode futures for each predicate output. + outputs: Vec>, +} + +impl PredicateProof { + /// Verifies that all predicates evaluated to true. + /// + /// Must be called after `vm.execute_all()` completes. + pub(crate) fn verify(self) -> Result<(), PredicateError> { + for mut output in self.outputs { + let result = output + .try_recv() + .map_err(PredicateError::Decode)? + .ok_or(PredicateError::MissingDecoding)?; + + if !result { + return Err(PredicateError::PredicateNotSatisfied); + } + } + Ok(()) + } +} + /// Verifies predicates over transcript data (verifier side). /// /// The verifier must provide the same predicates that the prover used, /// looked up by predicate name from out-of-band agreement. /// +/// Returns a [`PredicateProof`] that must be verified after `vm.execute_all()`. +/// /// # Arguments /// /// * `vm` - The zkVM. @@ -78,8 +119,9 @@ pub(crate) fn verify_predicates>( vm: &mut T, transcript_refs: &TranscriptRefs, predicates: impl IntoIterator, Pred)>, -) -> Result<(), PredicateError> { +) -> Result { let mut compiler = Compiler::new(); + let mut outputs = Vec::new(); for (direction, indices, predicate) in predicates { let refs = match direction { @@ -90,19 +132,22 @@ pub(crate) fn verify_predicates>( // Compile predicate to circuit let circuit = compiler.compile(&predicate); - execute_predicate(vm, refs, &indices, &circuit)?; + let output_fut = execute_predicate(vm, refs, &indices, &circuit)?; + outputs.push(output_fut); } - Ok(()) + Ok(PredicateProof { outputs }) } /// Executes a predicate circuit with transcript bytes as input. +/// +/// Returns a decode future for the circuit output. fn execute_predicate>( vm: &mut T, refs: &ReferenceMap, indices: &RangeSet, circuit: &Circuit, -) -> Result<(), PredicateError> { +) -> Result, PredicateError> { // Get the transcript bytes for the predicate indices let indexed_refs = refs .index(indices) @@ -121,23 +166,10 @@ fn execute_predicate>( let call = call_builder.build()?; - // Execute the circuit - output is a single bit (bool) - let output: Vector = vm.call(call)?; + // Execute the circuit - output is a single bit (true/false) + // Both parties must call decode() on the output to reveal it + let output: Bool = vm.call(call)?; - // The output should be a single bit indicating predicate satisfaction. - // We mark it public so both parties can see the result. - vm.mark_public(output)?; - vm.commit(output)?; - - // Decode the result to verify it's true - let result_fut = vm.decode(output)?; - - // Note: The actual verification that the output is true happens during - // execution. If the predicate is false, the ZK proof will fail. - - // Drop the future - we don't need to wait for it here since execute_all - // will handle this - drop(result_fut); - - Ok(()) + // Return decode future - caller must verify output == true after execute_all + Ok(vm.decode(output)?) } diff --git a/crates/tlsn/src/verifier/verify.rs b/crates/tlsn/src/verifier/verify.rs index 3b6603460..7d2ff52df 100644 --- a/crates/tlsn/src/verifier/verify.rs +++ b/crates/tlsn/src/verifier/verify.rs @@ -174,7 +174,7 @@ pub(crate) async fn verify + KeyStore + Send + Sync>( } // Verify predicates if any were requested - if !request.predicates().is_empty() { + let predicate_proof = if !request.predicates().is_empty() { let resolver = predicate_resolver.ok_or_else(|| { VerifierError::predicate("predicates requested but no resolver provided") })?; @@ -190,14 +190,21 @@ pub(crate) async fn verify + KeyStore + Send + Sync>( }) .collect::, VerifierError>>()?; - verify_predicates(vm, &transcript_refs, predicates).map_err(VerifierError::predicate)?; - } + Some(verify_predicates(vm, &transcript_refs, predicates).map_err(VerifierError::predicate)?) + } else { + None + }; vm.execute_all(ctx).await.map_err(VerifierError::zk)?; sent_proof.verify().map_err(VerifierError::verify)?; recv_proof.verify().map_err(VerifierError::verify)?; + // Verify predicate outputs after ZK execution + if let Some(proof) = predicate_proof { + proof.verify().map_err(VerifierError::predicate)?; + } + let mut encoder_secret = None; if let Some(commit_config) = request.transcript_commit() && let Some((sent, recv)) = commit_config.encoding() diff --git a/crates/tlsn/tests/test.rs b/crates/tlsn/tests/test.rs index 22ddbdde3..5f2e4f71a 100644 --- a/crates/tlsn/tests/test.rs +++ b/crates/tlsn/tests/test.rs @@ -1,4 +1,5 @@ use futures::{AsyncReadExt, AsyncWriteExt}; +use mpz_predicate::{Pred, eq}; use rangeset::set::RangeSet; use tlsn::{ config::{ @@ -231,3 +232,184 @@ async fn verifier( output } + +// Predicate name for testing +const TEST_PREDICATE: &str = "test_first_byte"; + +/// Test that a correct predicate passes verification. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[ignore] +async fn test_predicate_passes() { + let (socket_0, socket_1) = tokio::io::duplex(2 << 23); + + // Request is "GET / HTTP/1.1\r\n..." - index 10 is '/' (in "HTTP/1.1") + // Using index 10 to avoid overlap with revealed range (0..10) + // Verifier uses the same predicate - should pass + let prover_predicate = eq(10, b'/'); + + let (prover_result, verifier_result) = tokio::join!( + prover_with_predicate(socket_0, prover_predicate), + verifier_with_predicate(socket_1, || eq(10, b'/')) + ); + + prover_result.expect("prover should succeed"); + verifier_result.expect("verifier should succeed with correct predicate"); +} + +/// Test that a wrong predicate is rejected by the verifier. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[ignore] +async fn test_wrong_predicate_rejected() { + let (socket_0, socket_1) = tokio::io::duplex(2 << 23); + + // Request is "GET / HTTP/1.1\r\n..." - index 10 is '/' + // Verifier uses a DIFFERENT predicate that checks for 'X' - should fail + let prover_predicate = eq(10, b'/'); + + let (prover_result, verifier_result) = tokio::join!( + prover_with_predicate(socket_0, prover_predicate), + verifier_with_predicate(socket_1, || eq(10, b'X')) + ); + + // Prover may succeed or fail depending on when verifier rejects + let _ = prover_result; + + // Verifier should fail because predicate evaluates to false + assert!( + verifier_result.is_err(), + "verifier should reject wrong predicate" + ); +} + +/// Test that prover can't prove a predicate their data doesn't satisfy. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[ignore] +async fn test_unsatisfied_predicate_rejected() { + let (socket_0, socket_1) = tokio::io::duplex(2 << 23); + + // Request is "GET / HTTP/1.1\r\n..." - index 10 is '/' + // Both parties use eq(10, b'X') but prover's data has '/' at index 10 + // This tests that a prover can't cheat - the predicate must actually be satisfied + let prover_predicate = eq(10, b'X'); + + let (prover_result, verifier_result) = tokio::join!( + prover_with_predicate(socket_0, prover_predicate), + verifier_with_predicate(socket_1, || eq(10, b'X')) + ); + + // Prover may succeed or fail depending on when verifier rejects + let _ = prover_result; + + // Verifier should fail because prover's data doesn't satisfy the predicate + assert!( + verifier_result.is_err(), + "verifier should reject unsatisfied predicate" + ); +} + +#[instrument(skip(verifier_socket, predicate))] +async fn prover_with_predicate( + verifier_socket: T, + predicate: Pred, +) -> Result<(), Box> { + let (client_socket, server_socket) = tokio::io::duplex(2 << 16); + + let server_task = tokio::spawn(bind(server_socket.compat())); + + let prover = Prover::new(ProverConfig::builder().build()?) + .commit( + TlsCommitConfig::builder() + .protocol( + MpcTlsConfig::builder() + .max_sent_data(MAX_SENT_DATA) + .max_sent_records(MAX_SENT_RECORDS) + .max_recv_data(MAX_RECV_DATA) + .max_recv_records_online(MAX_RECV_RECORDS) + .build()?, + ) + .build()?, + verifier_socket.compat(), + ) + .await?; + + let (mut tls_connection, prover_fut) = prover + .connect( + TlsClientConfig::builder() + .server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?)) + .root_store(RootCertStore { + roots: vec![CertificateDer(CA_CERT_DER.to_vec())], + }) + .build()?, + client_socket.compat(), + ) + .await?; + let prover_task = tokio::spawn(prover_fut); + + tls_connection + .write_all(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n") + .await?; + tls_connection.close().await?; + + let mut response = vec![0u8; 1024]; + tls_connection.read_to_end(&mut response).await?; + + let _ = server_task.await?; + + let mut prover = prover_task.await??; + + let mut builder = ProveConfig::builder(prover.transcript()); + builder.server_identity(); + builder.reveal_sent(&(0..10))?; + builder.reveal_recv(&(0..10))?; + builder.predicate(TEST_PREDICATE, Direction::Sent, predicate)?; + + let config = builder.build()?; + prover.prove(&config).await?; + prover.close().await?; + + Ok(()) +} + +async fn verifier_with_predicate( + socket: T, + make_predicate: F, +) -> Result> +where + T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static, + F: Fn() -> Pred + Send + Sync + 'static, +{ + let verifier = Verifier::new( + VerifierConfig::builder() + .root_store(RootCertStore { + roots: vec![CertificateDer(CA_CERT_DER.to_vec())], + }) + .build()?, + ); + + let verifier = verifier + .commit(socket.compat()) + .await? + .accept() + .await? + .run() + .await?; + + let verifier = verifier.verify().await?; + + // Resolver that builds the predicate fresh (Pred uses Rc, so can't be shared) + let predicate_resolver = move |name: &str, _indices: &RangeSet| -> Option { + if name == TEST_PREDICATE { + Some(make_predicate()) + } else { + None + } + }; + + let (output, verifier) = verifier + .accept_with_predicates(Some(&predicate_resolver)) + .await?; + + verifier.close().await?; + + Ok(output) +}