mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-01-11 15:47:58 -05:00
Compare commits
1 Commits
dev
...
refactor/i
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a086ee91f |
129
Cargo.lock
generated
129
Cargo.lock
generated
@@ -174,9 +174,9 @@ checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
|
||||
|
||||
[[package]]
|
||||
name = "alloy-consensus"
|
||||
version = "1.1.2"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b6440213a22df93a87ed512d2f668e7dc1d62a05642d107f82d61edc9e12370"
|
||||
checksum = "2e318e25fb719e747a7e8db1654170fc185024f3ed5b10f86c08d448a912f6e2"
|
||||
dependencies = [
|
||||
"alloy-eips",
|
||||
"alloy-primitives",
|
||||
@@ -201,9 +201,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "alloy-consensus-any"
|
||||
version = "1.1.2"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "15d0bea09287942405c4f9d2a4f22d1e07611c2dbd9d5bf94b75366340f9e6e0"
|
||||
checksum = "364380a845193a317bcb7a5398fc86cdb66c47ebe010771dde05f6869bf9e64a"
|
||||
dependencies = [
|
||||
"alloy-consensus",
|
||||
"alloy-eips",
|
||||
@@ -253,9 +253,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "alloy-eips"
|
||||
version = "1.1.2"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4bd2c7ae05abcab4483ce821f12f285e01c0b33804e6883dd9ca1569a87ee2be"
|
||||
checksum = "a4c4d7c5839d9f3a467900c625416b24328450c65702eb3d8caff8813e4d1d33"
|
||||
dependencies = [
|
||||
"alloy-eip2124",
|
||||
"alloy-eip2930",
|
||||
@@ -288,9 +288,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "alloy-json-rpc"
|
||||
version = "1.1.2"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "003f46c54f22854a32b9cc7972660a476968008ad505427eabab49225309ec40"
|
||||
checksum = "f72cf87cda808e593381fb9f005ffa4d2475552b7a6c5ac33d087bf77d82abd0"
|
||||
dependencies = [
|
||||
"alloy-primitives",
|
||||
"alloy-sol-types",
|
||||
@@ -303,9 +303,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "alloy-network"
|
||||
version = "1.1.2"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4f4029954d9406a40979f3a3b46950928a0fdcfe3ea8a9b0c17490d57e8aa0e3"
|
||||
checksum = "12aeb37b6f2e61b93b1c3d34d01ee720207c76fe447e2a2c217e433ac75b17f5"
|
||||
dependencies = [
|
||||
"alloy-consensus",
|
||||
"alloy-consensus-any",
|
||||
@@ -329,9 +329,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "alloy-network-primitives"
|
||||
version = "1.1.2"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7805124ad69e57bbae7731c9c344571700b2a18d351bda9e0eba521c991d1bcb"
|
||||
checksum = "abd29ace62872083e30929cd9b282d82723196d196db589f3ceda67edcc05552"
|
||||
dependencies = [
|
||||
"alloy-consensus",
|
||||
"alloy-eips",
|
||||
@@ -391,9 +391,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "alloy-rpc-types-any"
|
||||
version = "1.1.2"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b43c1622aac2508d528743fd4cfdac1dea92d5a8fa894038488ff7edd0af0b32"
|
||||
checksum = "6a63fb40ed24e4c92505f488f9dd256e2afaed17faa1b7a221086ebba74f4122"
|
||||
dependencies = [
|
||||
"alloy-consensus-any",
|
||||
"alloy-rpc-types-eth",
|
||||
@@ -402,9 +402,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "alloy-rpc-types-eth"
|
||||
version = "1.1.2"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed5fafb741c19b3cca4cdd04fa215c89413491f9695a3e928dee2ae5657f607e"
|
||||
checksum = "9eae0c7c40da20684548cbc8577b6b7447f7bf4ddbac363df95e3da220e41e72"
|
||||
dependencies = [
|
||||
"alloy-consensus",
|
||||
"alloy-consensus-any",
|
||||
@@ -423,9 +423,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "alloy-serde"
|
||||
version = "1.1.2"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a6f180c399ca7c1e2fe17ea58343910cad0090878a696ff5a50241aee12fc529"
|
||||
checksum = "c0df1987ed0ff2d0159d76b52e7ddfc4e4fbddacc54d2fbee765e0d14d7c01b5"
|
||||
dependencies = [
|
||||
"alloy-primitives",
|
||||
"serde",
|
||||
@@ -434,9 +434,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "alloy-signer"
|
||||
version = "1.1.2"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ecc39ad2c0a3d2da8891f4081565780703a593f090f768f884049aa3aa929cbc"
|
||||
checksum = "6ff69deedee7232d7ce5330259025b868c5e6a52fa8dffda2c861fb3a5889b24"
|
||||
dependencies = [
|
||||
"alloy-primitives",
|
||||
"async-trait",
|
||||
@@ -449,9 +449,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "alloy-signer-local"
|
||||
version = "1.1.2"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "930e17cb1e46446a193a593a3bfff8d0ecee4e510b802575ebe300ae2e43ef75"
|
||||
checksum = "72cfe0be3ec5a8c1a46b2e5a7047ed41121d360d97f4405bb7c1c784880c86cb"
|
||||
dependencies = [
|
||||
"alloy-consensus",
|
||||
"alloy-network",
|
||||
@@ -551,9 +551,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "alloy-tx-macros"
|
||||
version = "1.1.2"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ae109e33814b49fc0a62f2528993aa8a2dd346c26959b151f05441dc0b9da292"
|
||||
checksum = "333544408503f42d7d3792bfc0f7218b643d968a03d2c0ed383ae558fb4a76d0"
|
||||
dependencies = [
|
||||
"darling 0.21.3",
|
||||
"proc-macro2",
|
||||
@@ -1783,9 +1783,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.2.48"
|
||||
version = "1.2.49"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c481bdbf0ed3b892f6f806287d72acd515b352a4ec27a208489b8c1bc839633a"
|
||||
checksum = "90583009037521a116abf44494efecd645ba48b6622457080f080b85544e2215"
|
||||
dependencies = [
|
||||
"find-msvc-tools",
|
||||
"shlex",
|
||||
@@ -3220,6 +3220,14 @@ dependencies = [
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-plex"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/tlsnotary/tlsn-utils?rev=c210f2f#c210f2fdd0a5d71c3e217fa03127c9f616314836"
|
||||
dependencies = [
|
||||
"futures",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-rustls"
|
||||
version = "0.25.1"
|
||||
@@ -3766,9 +3774,9 @@ checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a"
|
||||
|
||||
[[package]]
|
||||
name = "icu_properties"
|
||||
version = "2.1.1"
|
||||
version = "2.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99"
|
||||
checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec"
|
||||
dependencies = [
|
||||
"icu_collections",
|
||||
"icu_locale_core",
|
||||
@@ -3780,9 +3788,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "icu_properties_data"
|
||||
version = "2.1.1"
|
||||
version = "2.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899"
|
||||
checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af"
|
||||
|
||||
[[package]]
|
||||
name = "icu_provider"
|
||||
@@ -4302,9 +4310,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "mio"
|
||||
version = "1.1.0"
|
||||
version = "1.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69d83b0086dc8ecf3ce9ae2874b2d1290252e2a30720bea58a5c6639b0092873"
|
||||
checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"wasi",
|
||||
@@ -5486,7 +5494,7 @@ version = "3.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "219cb19e96be00ab2e37d6e299658a0cfa83e52429179969b0f0121b4ac46983"
|
||||
dependencies = [
|
||||
"toml_edit 0.23.7",
|
||||
"toml_edit 0.23.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5901,9 +5909,9 @@ checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58"
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.12.24"
|
||||
version = "0.12.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d0946410b9f7b082a427e4ef5c8ff541a88b357bc6c637c40db3a68ac70a36f"
|
||||
checksum = "b6eff9328d40131d43bd911d42d79eb6a47312002a4daefc9e37f17e74a7701a"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
@@ -5930,7 +5938,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tower",
|
||||
"tower-http 0.6.7",
|
||||
"tower-http 0.6.8",
|
||||
"tower-service",
|
||||
"url",
|
||||
"wasm-bindgen",
|
||||
@@ -6788,9 +6796,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "simd-adler32"
|
||||
version = "0.3.7"
|
||||
version = "0.3.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe"
|
||||
checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2"
|
||||
|
||||
[[package]]
|
||||
name = "sized-chunks"
|
||||
@@ -6854,9 +6862,9 @@ checksum = "bceb57dc07c92cdae60f5b27b3fa92ecaaa42fe36c55e22dbfb0b44893e0b1f7"
|
||||
|
||||
[[package]]
|
||||
name = "sourcemap"
|
||||
version = "9.2.2"
|
||||
version = "9.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e22afbcb92ce02d23815b9795523c005cb9d3c214f8b7a66318541c240ea7935"
|
||||
checksum = "37ccaaa78a0ca68b20f8f711eaa2522a00131c48a3de5b892ca5c36cec1ce9bb"
|
||||
dependencies = [
|
||||
"base64-simd",
|
||||
"bitvec",
|
||||
@@ -7032,9 +7040,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "sys_traits"
|
||||
version = "0.1.19"
|
||||
version = "0.1.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e1495a604cd38eeb30c408724966cd31ca1b68b5a97e3afc474c0d719bfeec5a"
|
||||
checksum = "6b61f4a25d0baba25511bed00c39c199d9a19cfd8107f4472724b72a84f530b1"
|
||||
dependencies = [
|
||||
"sys_traits_macros",
|
||||
]
|
||||
@@ -7255,6 +7263,7 @@ dependencies = [
|
||||
"aes 0.8.4",
|
||||
"ctr 0.9.2",
|
||||
"futures",
|
||||
"futures-plex",
|
||||
"ghash 0.5.1",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
@@ -7272,6 +7281,7 @@ dependencies = [
|
||||
"mpz-zk",
|
||||
"once_cell",
|
||||
"opaque-debug",
|
||||
"pin-project-lite",
|
||||
"rand 0.9.2",
|
||||
"rangeset 0.4.0",
|
||||
"rstest",
|
||||
@@ -7289,7 +7299,6 @@ dependencies = [
|
||||
"tlsn-server-fixture",
|
||||
"tlsn-server-fixture-certs",
|
||||
"tlsn-tls-client",
|
||||
"tlsn-tls-client-async",
|
||||
"tlsn-tls-core",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
@@ -7314,7 +7323,6 @@ dependencies = [
|
||||
"p256",
|
||||
"rand 0.9.2",
|
||||
"rand06-compat",
|
||||
"rangeset 0.4.0",
|
||||
"rstest",
|
||||
"serde",
|
||||
"thiserror 1.0.69",
|
||||
@@ -7609,7 +7617,6 @@ dependencies = [
|
||||
"tlsn-key-exchange",
|
||||
"tlsn-tls-backend",
|
||||
"tlsn-tls-client",
|
||||
"tlsn-tls-client-async",
|
||||
"tlsn-tls-core",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
@@ -7636,7 +7643,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tower",
|
||||
"tower-http 0.6.7",
|
||||
"tower-http 0.6.8",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
@@ -7684,26 +7691,6 @@ dependencies = [
|
||||
"webpki-roots 1.0.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tlsn-tls-client-async"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"rstest",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki 0.103.8",
|
||||
"thiserror 1.0.69",
|
||||
"tls-server-fixture",
|
||||
"tlsn-tls-client",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tlsn-tls-core"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
@@ -7732,6 +7719,7 @@ source = "git+https://github.com/tlsnotary/tlsn-utils?rev=6168663#6168663495281f
|
||||
name = "tlsn-wasm"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
dependencies = [
|
||||
"async_io_stream",
|
||||
"bincode 1.3.3",
|
||||
"console_error_panic_hook",
|
||||
"enum-try-as-inner",
|
||||
@@ -7750,7 +7738,6 @@ dependencies = [
|
||||
"tlsn",
|
||||
"tlsn-core",
|
||||
"tlsn-server-fixture-certs",
|
||||
"tlsn-tls-client-async",
|
||||
"tlsn-tls-core",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
@@ -7897,9 +7884,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "toml_edit"
|
||||
version = "0.23.7"
|
||||
version = "0.23.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6485ef6d0d9b5d0ec17244ff7eb05310113c3f316f2d14200d4de56b3cb98f8d"
|
||||
checksum = "5d7cbc3b4b49633d57a0509303158ca50de80ae32c265093b24c414705807832"
|
||||
dependencies = [
|
||||
"indexmap 2.12.1",
|
||||
"toml_datetime 0.7.3",
|
||||
@@ -7965,9 +7952,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tower-http"
|
||||
version = "0.6.7"
|
||||
version = "0.6.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9cf146f99d442e8e68e585f5d798ccd3cad9a7835b917e09728880a862706456"
|
||||
checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"bytes",
|
||||
|
||||
@@ -13,7 +13,6 @@ members = [
|
||||
"crates/server-fixture/server",
|
||||
"crates/tls/backend",
|
||||
"crates/tls/client",
|
||||
"crates/tls/client-async",
|
||||
"crates/tls/core",
|
||||
"crates/mpc-tls",
|
||||
"crates/tls/server-fixture",
|
||||
@@ -57,7 +56,6 @@ tlsn-server-fixture = { path = "crates/server-fixture/server" }
|
||||
tlsn-server-fixture-certs = { path = "crates/server-fixture/certs" }
|
||||
tlsn-tls-backend = { path = "crates/tls/backend" }
|
||||
tlsn-tls-client = { path = "crates/tls/client" }
|
||||
tlsn-tls-client-async = { path = "crates/tls/client-async" }
|
||||
tlsn-tls-core = { path = "crates/tls/core" }
|
||||
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
|
||||
tlsn-harness-core = { path = "crates/harness/core" }
|
||||
@@ -82,6 +80,7 @@ mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
|
||||
futures-plex = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "c210f2f" }
|
||||
rangeset = { version = "0.4" }
|
||||
serio = { version = "0.2" }
|
||||
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
|
||||
|
||||
@@ -27,7 +27,6 @@ alloy-primitives = { version = "1.3.1", default-features = false }
|
||||
alloy-signer = { version = "1.0", default-features = false }
|
||||
alloy-signer-local = { version = "1.0", default-features = false }
|
||||
rand06-compat = { workspace = true }
|
||||
rangeset = { workspace = true }
|
||||
rstest = { workspace = true }
|
||||
tlsn-core = { workspace = true, features = ["fixtures"] }
|
||||
tlsn-data-fixtures = { workspace = true }
|
||||
|
||||
@@ -5,7 +5,7 @@ use rand::{Rng, rng};
|
||||
use tlsn_core::{
|
||||
connection::{ConnectionInfo, ServerEphemKey},
|
||||
hash::HashAlgId,
|
||||
transcript::TranscriptCommitment,
|
||||
transcript::{TranscriptCommitment, encoding::EncoderSecret},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -25,6 +25,7 @@ pub struct Sign {
|
||||
connection_info: Option<ConnectionInfo>,
|
||||
server_ephemeral_key: Option<ServerEphemKey>,
|
||||
cert_commitment: ServerCertCommitment,
|
||||
encoder_secret: Option<EncoderSecret>,
|
||||
extensions: Vec<Extension>,
|
||||
transcript_commitments: Vec<TranscriptCommitment>,
|
||||
}
|
||||
@@ -86,6 +87,7 @@ impl<'a> AttestationBuilder<'a, Accept> {
|
||||
connection_info: None,
|
||||
server_ephemeral_key: None,
|
||||
cert_commitment,
|
||||
encoder_secret: None,
|
||||
transcript_commitments: Vec::new(),
|
||||
extensions,
|
||||
},
|
||||
@@ -106,6 +108,12 @@ impl AttestationBuilder<'_, Sign> {
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the secret for encoding commitments.
|
||||
pub fn encoder_secret(&mut self, secret: EncoderSecret) -> &mut Self {
|
||||
self.state.encoder_secret = Some(secret);
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds an extension to the attestation.
|
||||
pub fn extension(&mut self, extension: Extension) -> &mut Self {
|
||||
self.state.extensions.push(extension);
|
||||
@@ -129,6 +137,7 @@ impl AttestationBuilder<'_, Sign> {
|
||||
connection_info,
|
||||
server_ephemeral_key,
|
||||
cert_commitment,
|
||||
encoder_secret,
|
||||
extensions,
|
||||
transcript_commitments,
|
||||
} = self.state;
|
||||
@@ -159,6 +168,7 @@ impl AttestationBuilder<'_, Sign> {
|
||||
AttestationBuilderError::new(ErrorKind::Field, "handshake data was not set")
|
||||
})?),
|
||||
cert_commitment: field_id.next(cert_commitment),
|
||||
encoder_secret: encoder_secret.map(|secret| field_id.next(secret)),
|
||||
extensions: extensions
|
||||
.into_iter()
|
||||
.map(|extension| field_id.next(extension))
|
||||
@@ -243,7 +253,8 @@ mod test {
|
||||
use rstest::{fixture, rstest};
|
||||
use tlsn_core::{
|
||||
connection::{CertBinding, CertBindingV1_2},
|
||||
fixtures::ConnectionFixture,
|
||||
fixtures::{ConnectionFixture, encoding_provider},
|
||||
hash::Blake3,
|
||||
transcript::Transcript,
|
||||
};
|
||||
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
|
||||
@@ -274,7 +285,13 @@ mod test {
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
let connection = ConnectionFixture::tlsnotary(transcript.length());
|
||||
|
||||
let RequestFixture { request, .. } = request_fixture(transcript, connection, Vec::new());
|
||||
let RequestFixture { request, .. } = request_fixture(
|
||||
transcript,
|
||||
encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
connection,
|
||||
Blake3::default(),
|
||||
Vec::new(),
|
||||
);
|
||||
|
||||
let attestation_config = AttestationConfig::builder()
|
||||
.supported_signature_algs([SignatureAlgId::SECP256R1])
|
||||
@@ -293,7 +310,13 @@ mod test {
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
let connection = ConnectionFixture::tlsnotary(transcript.length());
|
||||
|
||||
let RequestFixture { request, .. } = request_fixture(transcript, connection, Vec::new());
|
||||
let RequestFixture { request, .. } = request_fixture(
|
||||
transcript,
|
||||
encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
connection,
|
||||
Blake3::default(),
|
||||
Vec::new(),
|
||||
);
|
||||
|
||||
let attestation_config = AttestationConfig::builder()
|
||||
.supported_signature_algs([SignatureAlgId::SECP256K1])
|
||||
@@ -313,7 +336,13 @@ mod test {
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
let connection = ConnectionFixture::tlsnotary(transcript.length());
|
||||
|
||||
let RequestFixture { request, .. } = request_fixture(transcript, connection, Vec::new());
|
||||
let RequestFixture { request, .. } = request_fixture(
|
||||
transcript,
|
||||
encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
connection,
|
||||
Blake3::default(),
|
||||
Vec::new(),
|
||||
);
|
||||
|
||||
let attestation_builder = Attestation::builder(attestation_config)
|
||||
.accept_request(request)
|
||||
@@ -334,8 +363,13 @@ mod test {
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
let connection = ConnectionFixture::tlsnotary(transcript.length());
|
||||
|
||||
let RequestFixture { request, .. } =
|
||||
request_fixture(transcript, connection.clone(), Vec::new());
|
||||
let RequestFixture { request, .. } = request_fixture(
|
||||
transcript,
|
||||
encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
connection.clone(),
|
||||
Blake3::default(),
|
||||
Vec::new(),
|
||||
);
|
||||
|
||||
let mut attestation_builder = Attestation::builder(attestation_config)
|
||||
.accept_request(request)
|
||||
@@ -359,8 +393,13 @@ mod test {
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
let connection = ConnectionFixture::tlsnotary(transcript.length());
|
||||
|
||||
let RequestFixture { request, .. } =
|
||||
request_fixture(transcript, connection.clone(), Vec::new());
|
||||
let RequestFixture { request, .. } = request_fixture(
|
||||
transcript,
|
||||
encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
connection.clone(),
|
||||
Blake3::default(),
|
||||
Vec::new(),
|
||||
);
|
||||
|
||||
let mut attestation_builder = Attestation::builder(attestation_config)
|
||||
.accept_request(request)
|
||||
@@ -393,7 +432,9 @@ mod test {
|
||||
|
||||
let RequestFixture { request, .. } = request_fixture(
|
||||
transcript,
|
||||
encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
connection.clone(),
|
||||
Blake3::default(),
|
||||
vec![Extension {
|
||||
id: b"foo".to_vec(),
|
||||
value: b"bar".to_vec(),
|
||||
@@ -420,7 +461,9 @@ mod test {
|
||||
|
||||
let RequestFixture { request, .. } = request_fixture(
|
||||
transcript,
|
||||
encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
connection.clone(),
|
||||
Blake3::default(),
|
||||
vec![Extension {
|
||||
id: b"foo".to_vec(),
|
||||
value: b"bar".to_vec(),
|
||||
|
||||
@@ -2,7 +2,11 @@
|
||||
use tlsn_core::{
|
||||
connection::{CertBinding, CertBindingV1_2},
|
||||
fixtures::ConnectionFixture,
|
||||
transcript::{Transcript, TranscriptCommitConfigBuilder, TranscriptCommitment},
|
||||
hash::HashAlgorithm,
|
||||
transcript::{
|
||||
Transcript, TranscriptCommitConfigBuilder, TranscriptCommitment,
|
||||
encoding::{EncodingProvider, EncodingTree},
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -17,13 +21,16 @@ use crate::{
|
||||
/// A Request fixture used for testing.
|
||||
#[allow(missing_docs)]
|
||||
pub struct RequestFixture {
|
||||
pub encoding_tree: EncodingTree,
|
||||
pub request: Request,
|
||||
}
|
||||
|
||||
/// Returns a request fixture for testing.
|
||||
pub fn request_fixture(
|
||||
transcript: Transcript,
|
||||
encodings_provider: impl EncodingProvider,
|
||||
connection: ConnectionFixture,
|
||||
encoding_hasher: impl HashAlgorithm,
|
||||
extensions: Vec<Extension>,
|
||||
) -> RequestFixture {
|
||||
let provider = CryptoProvider::default();
|
||||
@@ -43,9 +50,15 @@ pub fn request_fixture(
|
||||
.unwrap();
|
||||
let transcripts_commitment_config = transcript_commitment_builder.build().unwrap();
|
||||
|
||||
let mut builder = RequestConfig::builder();
|
||||
// Prover constructs encoding tree.
|
||||
let encoding_tree = EncodingTree::new(
|
||||
&encoding_hasher,
|
||||
transcripts_commitment_config.iter_encoding(),
|
||||
&encodings_provider,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
builder.transcript_commit(transcripts_commitment_config);
|
||||
let mut builder = RequestConfig::builder();
|
||||
|
||||
for extension in extensions {
|
||||
builder.extension(extension);
|
||||
@@ -61,7 +74,10 @@ pub fn request_fixture(
|
||||
|
||||
let (request, _) = request_builder.build(&provider).unwrap();
|
||||
|
||||
RequestFixture { request }
|
||||
RequestFixture {
|
||||
encoding_tree,
|
||||
request,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an attestation fixture for testing.
|
||||
|
||||
@@ -79,6 +79,8 @@
|
||||
//!
|
||||
//! // Specify all the transcript commitments we want to make.
|
||||
//! builder
|
||||
//! // Use BLAKE3 for encoding commitments.
|
||||
//! .encoding_hash_alg(HashAlgId::BLAKE3)
|
||||
//! // Commit to all sent data.
|
||||
//! .commit_sent(&(0..sent_len))?
|
||||
//! // Commit to the first 10 bytes of sent data.
|
||||
@@ -127,7 +129,7 @@
|
||||
//!
|
||||
//! ```no_run
|
||||
//! # use tlsn_attestation::{Attestation, CryptoProvider, Secrets, presentation::Presentation};
|
||||
//! # use tlsn_core::transcript::Direction;
|
||||
//! # use tlsn_core::transcript::{TranscriptCommitmentKind, Direction};
|
||||
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
//! # let attestation: Attestation = unimplemented!();
|
||||
//! # let secrets: Secrets = unimplemented!();
|
||||
@@ -138,6 +140,8 @@
|
||||
//! let mut builder = secrets.transcript_proof_builder();
|
||||
//!
|
||||
//! builder
|
||||
//! // Use transcript encoding commitments.
|
||||
//! .commitment_kinds(&[TranscriptCommitmentKind::Encoding])
|
||||
//! // Disclose the first 10 bytes of the sent data.
|
||||
//! .reveal(&(0..10), Direction::Sent)?
|
||||
//! // Disclose all of the received data.
|
||||
@@ -215,7 +219,7 @@ use tlsn_core::{
|
||||
connection::{ConnectionInfo, ServerEphemKey},
|
||||
hash::{Hash, HashAlgorithm, TypedHash},
|
||||
merkle::MerkleTree,
|
||||
transcript::TranscriptCommitment,
|
||||
transcript::{TranscriptCommitment, encoding::EncoderSecret},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -297,6 +301,8 @@ pub enum FieldKind {
|
||||
ServerEphemKey = 0x02,
|
||||
/// Server identity commitment.
|
||||
ServerIdentityCommitment = 0x03,
|
||||
/// Encoding commitment.
|
||||
EncodingCommitment = 0x04,
|
||||
/// Plaintext hash commitment.
|
||||
PlaintextHash = 0x05,
|
||||
}
|
||||
@@ -321,6 +327,7 @@ pub struct Body {
|
||||
connection_info: Field<ConnectionInfo>,
|
||||
server_ephemeral_key: Field<ServerEphemKey>,
|
||||
cert_commitment: Field<ServerCertCommitment>,
|
||||
encoder_secret: Option<Field<EncoderSecret>>,
|
||||
extensions: Vec<Field<Extension>>,
|
||||
transcript_commitments: Vec<Field<TranscriptCommitment>>,
|
||||
}
|
||||
@@ -366,6 +373,7 @@ impl Body {
|
||||
connection_info: conn_info,
|
||||
server_ephemeral_key,
|
||||
cert_commitment,
|
||||
encoder_secret,
|
||||
extensions,
|
||||
transcript_commitments,
|
||||
} = self;
|
||||
@@ -383,6 +391,13 @@ impl Body {
|
||||
),
|
||||
];
|
||||
|
||||
if let Some(encoder_secret) = encoder_secret {
|
||||
fields.push((
|
||||
encoder_secret.id,
|
||||
hasher.hash_separated(&encoder_secret.data),
|
||||
));
|
||||
}
|
||||
|
||||
for field in extensions.iter() {
|
||||
fields.push((field.id, hasher.hash_separated(&field.data)));
|
||||
}
|
||||
|
||||
@@ -91,6 +91,11 @@ impl Presentation {
|
||||
transcript.verify_with_provider(
|
||||
&provider.hash,
|
||||
&attestation.body.connection_info().transcript_length,
|
||||
attestation
|
||||
.body
|
||||
.encoder_secret
|
||||
.as_ref()
|
||||
.map(|field| &field.data),
|
||||
attestation.body.transcript_commitments(),
|
||||
)
|
||||
})
|
||||
|
||||
@@ -144,7 +144,9 @@ impl std::fmt::Display for ErrorKind {
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use tlsn_core::{
|
||||
connection::TranscriptLength, fixtures::ConnectionFixture, hash::HashAlgId,
|
||||
connection::TranscriptLength,
|
||||
fixtures::{ConnectionFixture, encoding_provider},
|
||||
hash::{Blake3, HashAlgId},
|
||||
transcript::Transcript,
|
||||
};
|
||||
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
|
||||
@@ -162,8 +164,13 @@ mod test {
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
let connection = ConnectionFixture::tlsnotary(transcript.length());
|
||||
|
||||
let RequestFixture { request, .. } =
|
||||
request_fixture(transcript, connection.clone(), Vec::new());
|
||||
let RequestFixture { request, .. } = request_fixture(
|
||||
transcript,
|
||||
encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
connection.clone(),
|
||||
Blake3::default(),
|
||||
Vec::new(),
|
||||
);
|
||||
|
||||
let attestation =
|
||||
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
|
||||
@@ -178,8 +185,13 @@ mod test {
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
let connection = ConnectionFixture::tlsnotary(transcript.length());
|
||||
|
||||
let RequestFixture { mut request, .. } =
|
||||
request_fixture(transcript, connection.clone(), Vec::new());
|
||||
let RequestFixture { mut request, .. } = request_fixture(
|
||||
transcript,
|
||||
encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
connection.clone(),
|
||||
Blake3::default(),
|
||||
Vec::new(),
|
||||
);
|
||||
|
||||
let attestation =
|
||||
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
|
||||
@@ -197,8 +209,13 @@ mod test {
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
let connection = ConnectionFixture::tlsnotary(transcript.length());
|
||||
|
||||
let RequestFixture { mut request, .. } =
|
||||
request_fixture(transcript, connection.clone(), Vec::new());
|
||||
let RequestFixture { mut request, .. } = request_fixture(
|
||||
transcript,
|
||||
encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
connection.clone(),
|
||||
Blake3::default(),
|
||||
Vec::new(),
|
||||
);
|
||||
|
||||
let attestation =
|
||||
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
|
||||
@@ -216,8 +233,13 @@ mod test {
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
let connection = ConnectionFixture::tlsnotary(transcript.length());
|
||||
|
||||
let RequestFixture { mut request, .. } =
|
||||
request_fixture(transcript, connection.clone(), Vec::new());
|
||||
let RequestFixture { mut request, .. } = request_fixture(
|
||||
transcript,
|
||||
encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
connection.clone(),
|
||||
Blake3::default(),
|
||||
Vec::new(),
|
||||
);
|
||||
|
||||
let attestation =
|
||||
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
|
||||
@@ -243,8 +265,13 @@ mod test {
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
let connection = ConnectionFixture::tlsnotary(transcript.length());
|
||||
|
||||
let RequestFixture { request, .. } =
|
||||
request_fixture(transcript, connection.clone(), Vec::new());
|
||||
let RequestFixture { request, .. } = request_fixture(
|
||||
transcript,
|
||||
encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
connection.clone(),
|
||||
Blake3::default(),
|
||||
Vec::new(),
|
||||
);
|
||||
|
||||
let mut attestation =
|
||||
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
|
||||
@@ -262,8 +289,13 @@ mod test {
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
let connection = ConnectionFixture::tlsnotary(transcript.length());
|
||||
|
||||
let RequestFixture { request, .. } =
|
||||
request_fixture(transcript, connection.clone(), Vec::new());
|
||||
let RequestFixture { request, .. } = request_fixture(
|
||||
transcript,
|
||||
encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
connection.clone(),
|
||||
Blake3::default(),
|
||||
Vec::new(),
|
||||
);
|
||||
|
||||
let attestation =
|
||||
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
|
||||
|
||||
@@ -49,4 +49,6 @@ impl_domain_separator!(tlsn_core::connection::ConnectionInfo);
|
||||
impl_domain_separator!(tlsn_core::connection::CertBinding);
|
||||
impl_domain_separator!(tlsn_core::transcript::TranscriptCommitment);
|
||||
impl_domain_separator!(tlsn_core::transcript::TranscriptSecret);
|
||||
impl_domain_separator!(tlsn_core::transcript::encoding::EncoderSecret);
|
||||
impl_domain_separator!(tlsn_core::transcript::encoding::EncodingCommitment);
|
||||
impl_domain_separator!(tlsn_core::transcript::hash::PlaintextHash);
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
use rand::{Rng, SeedableRng, rngs::StdRng};
|
||||
use rangeset::set::RangeSet;
|
||||
use tlsn_attestation::{
|
||||
Attestation, AttestationConfig, CryptoProvider,
|
||||
presentation::PresentationOutput,
|
||||
@@ -8,11 +6,12 @@ use tlsn_attestation::{
|
||||
};
|
||||
use tlsn_core::{
|
||||
connection::{CertBinding, CertBindingV1_2},
|
||||
fixtures::ConnectionFixture,
|
||||
hash::{Blake3, Blinder, HashAlgId},
|
||||
fixtures::{self, ConnectionFixture, encoder_secret},
|
||||
hash::Blake3,
|
||||
transcript::{
|
||||
Direction, Transcript, TranscriptCommitment, TranscriptSecret,
|
||||
hash::{PlaintextHash, PlaintextHashSecret, hash_plaintext},
|
||||
Direction, Transcript, TranscriptCommitConfigBuilder, TranscriptCommitment,
|
||||
TranscriptSecret,
|
||||
encoding::{EncodingCommitment, EncodingTree},
|
||||
},
|
||||
};
|
||||
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
|
||||
@@ -20,7 +19,6 @@ use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
|
||||
/// Tests that the attestation protocol and verification work end-to-end
|
||||
#[test]
|
||||
fn test_api() {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let mut provider = CryptoProvider::default();
|
||||
|
||||
// Configure signer for Notary
|
||||
@@ -28,6 +26,8 @@ fn test_api() {
|
||||
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
let (sent_len, recv_len) = transcript.len();
|
||||
// Plaintext encodings which the Prover obtained from GC evaluation
|
||||
let encodings_provider = fixtures::encoding_provider(GET_WITH_HEADER, OK_JSON);
|
||||
|
||||
// At the end of the TLS connection the Prover holds the:
|
||||
let ConnectionFixture {
|
||||
@@ -44,38 +44,26 @@ fn test_api() {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
// Create hash commitments
|
||||
let hasher = Blake3::default();
|
||||
let sent_blinder: Blinder = rng.random();
|
||||
let recv_blinder: Blinder = rng.random();
|
||||
// Prover specifies the ranges it wants to commit to.
|
||||
let mut transcript_commitment_builder = TranscriptCommitConfigBuilder::new(&transcript);
|
||||
transcript_commitment_builder
|
||||
.commit_sent(&(0..sent_len))
|
||||
.unwrap()
|
||||
.commit_recv(&(0..recv_len))
|
||||
.unwrap();
|
||||
|
||||
let sent_idx = RangeSet::from(0..sent_len);
|
||||
let recv_idx = RangeSet::from(0..recv_len);
|
||||
let transcripts_commitment_config = transcript_commitment_builder.build().unwrap();
|
||||
|
||||
let sent_hash_commitment = PlaintextHash {
|
||||
direction: Direction::Sent,
|
||||
idx: sent_idx.clone(),
|
||||
hash: hash_plaintext(&hasher, transcript.sent(), &sent_blinder),
|
||||
};
|
||||
// Prover constructs encoding tree.
|
||||
let encoding_tree = EncodingTree::new(
|
||||
&Blake3::default(),
|
||||
transcripts_commitment_config.iter_encoding(),
|
||||
&encodings_provider,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let recv_hash_commitment = PlaintextHash {
|
||||
direction: Direction::Received,
|
||||
idx: recv_idx.clone(),
|
||||
hash: hash_plaintext(&hasher, transcript.received(), &recv_blinder),
|
||||
};
|
||||
|
||||
let sent_hash_secret = PlaintextHashSecret {
|
||||
direction: Direction::Sent,
|
||||
idx: sent_idx,
|
||||
alg: HashAlgId::BLAKE3,
|
||||
blinder: sent_blinder,
|
||||
};
|
||||
|
||||
let recv_hash_secret = PlaintextHashSecret {
|
||||
direction: Direction::Received,
|
||||
idx: recv_idx,
|
||||
alg: HashAlgId::BLAKE3,
|
||||
blinder: recv_blinder,
|
||||
let encoding_commitment = EncodingCommitment {
|
||||
root: encoding_tree.root(),
|
||||
};
|
||||
|
||||
let request_config = RequestConfig::default();
|
||||
@@ -86,14 +74,8 @@ fn test_api() {
|
||||
.handshake_data(server_cert_data)
|
||||
.transcript(transcript)
|
||||
.transcript_commitments(
|
||||
vec![
|
||||
TranscriptSecret::Hash(sent_hash_secret),
|
||||
TranscriptSecret::Hash(recv_hash_secret),
|
||||
],
|
||||
vec![
|
||||
TranscriptCommitment::Hash(sent_hash_commitment.clone()),
|
||||
TranscriptCommitment::Hash(recv_hash_commitment.clone()),
|
||||
],
|
||||
vec![TranscriptSecret::Encoding(encoding_tree)],
|
||||
vec![TranscriptCommitment::Encoding(encoding_commitment.clone())],
|
||||
);
|
||||
|
||||
let (request, secrets) = request_builder.build(&provider).unwrap();
|
||||
@@ -113,10 +95,8 @@ fn test_api() {
|
||||
.connection_info(connection_info.clone())
|
||||
// Server key Notary received during handshake
|
||||
.server_ephemeral_key(server_ephemeral_key)
|
||||
.transcript_commitments(vec![
|
||||
TranscriptCommitment::Hash(sent_hash_commitment),
|
||||
TranscriptCommitment::Hash(recv_hash_commitment),
|
||||
]);
|
||||
.encoder_secret(encoder_secret())
|
||||
.transcript_commitments(vec![TranscriptCommitment::Encoding(encoding_commitment)]);
|
||||
|
||||
let attestation = attestation_builder.build(&provider).unwrap();
|
||||
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
//! Fixtures for testing
|
||||
|
||||
mod provider;
|
||||
pub mod transcript;
|
||||
|
||||
pub use provider::FixtureEncodingProvider;
|
||||
|
||||
use hex::FromHex;
|
||||
|
||||
use crate::{
|
||||
@@ -10,6 +13,10 @@ use crate::{
|
||||
ServerEphemKey, ServerName, ServerSignature, SignatureAlgorithm, TlsVersion,
|
||||
TranscriptLength,
|
||||
},
|
||||
transcript::{
|
||||
encoding::{EncoderSecret, EncodingProvider},
|
||||
Transcript,
|
||||
},
|
||||
webpki::CertificateDer,
|
||||
};
|
||||
|
||||
@@ -122,3 +129,27 @@ impl ConnectionFixture {
|
||||
server_ephemeral_key
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an encoding provider fixture.
|
||||
pub fn encoding_provider(tx: &[u8], rx: &[u8]) -> impl EncodingProvider {
|
||||
let secret = encoder_secret();
|
||||
FixtureEncodingProvider::new(&secret, Transcript::new(tx, rx))
|
||||
}
|
||||
|
||||
/// Seed fixture.
|
||||
const SEED: [u8; 32] = [0; 32];
|
||||
|
||||
/// Delta fixture.
|
||||
const DELTA: [u8; 16] = [1; 16];
|
||||
|
||||
/// Returns an encoder secret fixture.
|
||||
pub fn encoder_secret() -> EncoderSecret {
|
||||
EncoderSecret::new(SEED, DELTA)
|
||||
}
|
||||
|
||||
/// Returns a tampered encoder secret fixture.
|
||||
pub fn encoder_secret_tampered_seed() -> EncoderSecret {
|
||||
let mut seed = SEED;
|
||||
seed[0] += 1;
|
||||
EncoderSecret::new(seed, DELTA)
|
||||
}
|
||||
|
||||
41
crates/core/src/fixtures/provider.rs
Normal file
41
crates/core/src/fixtures/provider.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::transcript::{
|
||||
encoding::{new_encoder, Encoder, EncoderSecret, EncodingProvider, EncodingProviderError},
|
||||
Direction, Transcript,
|
||||
};
|
||||
|
||||
/// A encoding provider fixture.
|
||||
pub struct FixtureEncodingProvider {
|
||||
encoder: Box<dyn Encoder>,
|
||||
transcript: Transcript,
|
||||
}
|
||||
|
||||
impl FixtureEncodingProvider {
|
||||
/// Creates a new encoding provider fixture.
|
||||
pub(crate) fn new(secret: &EncoderSecret, transcript: Transcript) -> Self {
|
||||
Self {
|
||||
encoder: Box::new(new_encoder(secret)),
|
||||
transcript,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EncodingProvider for FixtureEncodingProvider {
|
||||
fn provide_encoding(
|
||||
&self,
|
||||
direction: Direction,
|
||||
range: Range<usize>,
|
||||
dest: &mut Vec<u8>,
|
||||
) -> Result<(), EncodingProviderError> {
|
||||
let transcript = match direction {
|
||||
Direction::Sent => &self.transcript.sent(),
|
||||
Direction::Received => &self.transcript.received(),
|
||||
};
|
||||
|
||||
let data = transcript.get(range.clone()).ok_or(EncodingProviderError)?;
|
||||
self.encoder.encode_data(direction, range, data, dest);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -19,7 +19,9 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
connection::ServerName,
|
||||
transcript::{PartialTranscript, TranscriptCommitment, TranscriptSecret},
|
||||
transcript::{
|
||||
encoding::EncoderSecret, PartialTranscript, TranscriptCommitment, TranscriptSecret,
|
||||
},
|
||||
};
|
||||
|
||||
/// Prover output.
|
||||
@@ -40,6 +42,8 @@ pub struct VerifierOutput {
|
||||
pub server_name: Option<ServerName>,
|
||||
/// Transcript data.
|
||||
pub transcript: Option<PartialTranscript>,
|
||||
/// Encoding commitment secret.
|
||||
pub encoder_secret: Option<EncoderSecret>,
|
||||
/// Transcript commitments.
|
||||
pub transcript_commitments: Vec<TranscriptCommitment>,
|
||||
}
|
||||
|
||||
@@ -63,6 +63,11 @@ impl MerkleProof {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the leaf count of the Merkle tree associated with the proof.
|
||||
pub(crate) fn leaf_count(&self) -> usize {
|
||||
self.leaf_count
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
//! withheld.
|
||||
|
||||
mod commit;
|
||||
pub mod encoding;
|
||||
pub mod hash;
|
||||
mod proof;
|
||||
mod tls;
|
||||
|
||||
@@ -8,15 +8,27 @@ use serde::{Deserialize, Serialize};
|
||||
use crate::{
|
||||
hash::HashAlgId,
|
||||
transcript::{
|
||||
encoding::{EncodingCommitment, EncodingTree},
|
||||
hash::{PlaintextHash, PlaintextHashSecret},
|
||||
Direction, RangeSet, Transcript,
|
||||
},
|
||||
};
|
||||
|
||||
/// The maximum allowed total bytelength of committed data for a single
|
||||
/// commitment kind. Used to prevent DoS during verification. (May cause the
|
||||
/// verifier to hash up to a max of 1GB * 128 = 128GB of data for certain kinds
|
||||
/// of encoding commitments.)
|
||||
///
|
||||
/// This value must not exceed bcs's MAX_SEQUENCE_LENGTH limit (which is (1 <<
|
||||
/// 31) - 1 by default)
|
||||
pub(crate) const MAX_TOTAL_COMMITTED_DATA: usize = 1_000_000_000;
|
||||
|
||||
/// Kind of transcript commitment.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[non_exhaustive]
|
||||
pub enum TranscriptCommitmentKind {
|
||||
/// A commitment to encodings of the transcript.
|
||||
Encoding,
|
||||
/// A hash commitment to plaintext in the transcript.
|
||||
Hash {
|
||||
/// The hash algorithm used.
|
||||
@@ -27,6 +39,7 @@ pub enum TranscriptCommitmentKind {
|
||||
impl fmt::Display for TranscriptCommitmentKind {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Encoding => f.write_str("encoding"),
|
||||
Self::Hash { alg } => write!(f, "hash ({alg})"),
|
||||
}
|
||||
}
|
||||
@@ -36,6 +49,8 @@ impl fmt::Display for TranscriptCommitmentKind {
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[non_exhaustive]
|
||||
pub enum TranscriptCommitment {
|
||||
/// Encoding commitment.
|
||||
Encoding(EncodingCommitment),
|
||||
/// Plaintext hash commitment.
|
||||
Hash(PlaintextHash),
|
||||
}
|
||||
@@ -44,6 +59,8 @@ pub enum TranscriptCommitment {
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[non_exhaustive]
|
||||
pub enum TranscriptSecret {
|
||||
/// Encoding tree.
|
||||
Encoding(EncodingTree),
|
||||
/// Plaintext hash secret.
|
||||
Hash(PlaintextHashSecret),
|
||||
}
|
||||
@@ -51,6 +68,9 @@ pub enum TranscriptSecret {
|
||||
/// Configuration for transcript commitments.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TranscriptCommitConfig {
|
||||
encoding_hash_alg: HashAlgId,
|
||||
has_encoding: bool,
|
||||
has_hash: bool,
|
||||
commits: Vec<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
|
||||
}
|
||||
|
||||
@@ -60,23 +80,53 @@ impl TranscriptCommitConfig {
|
||||
TranscriptCommitConfigBuilder::new(transcript)
|
||||
}
|
||||
|
||||
/// Returns the hash algorithm to use for encoding commitments.
|
||||
pub fn encoding_hash_alg(&self) -> &HashAlgId {
|
||||
&self.encoding_hash_alg
|
||||
}
|
||||
|
||||
/// Returns `true` if the configuration has any encoding commitments.
|
||||
pub fn has_encoding(&self) -> bool {
|
||||
self.has_encoding
|
||||
}
|
||||
|
||||
/// Returns `true` if the configuration has any hash commitments.
|
||||
pub fn has_hash(&self) -> bool {
|
||||
self.commits
|
||||
.iter()
|
||||
.any(|(_, kind)| matches!(kind, TranscriptCommitmentKind::Hash { .. }))
|
||||
self.has_hash
|
||||
}
|
||||
|
||||
/// Returns an iterator over the encoding commitment indices.
|
||||
pub fn iter_encoding(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
|
||||
self.commits.iter().filter_map(|(idx, kind)| match kind {
|
||||
TranscriptCommitmentKind::Encoding => Some(idx),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns an iterator over the hash commitment indices.
|
||||
pub fn iter_hash(&self) -> impl Iterator<Item = (&(Direction, RangeSet<usize>), &HashAlgId)> {
|
||||
self.commits.iter().map(|(idx, kind)| match kind {
|
||||
TranscriptCommitmentKind::Hash { alg } => (idx, alg),
|
||||
self.commits.iter().filter_map(|(idx, kind)| match kind {
|
||||
TranscriptCommitmentKind::Hash { alg } => Some((idx, alg)),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a request for the transcript commitments.
|
||||
pub fn to_request(&self) -> TranscriptCommitRequest {
|
||||
TranscriptCommitRequest {
|
||||
encoding: self.has_encoding.then(|| {
|
||||
let mut sent = RangeSet::default();
|
||||
let mut recv = RangeSet::default();
|
||||
|
||||
for (dir, idx) in self.iter_encoding() {
|
||||
match dir {
|
||||
Direction::Sent => sent.union_mut(idx),
|
||||
Direction::Received => recv.union_mut(idx),
|
||||
}
|
||||
}
|
||||
|
||||
(sent, recv)
|
||||
}),
|
||||
hash: self
|
||||
.iter_hash()
|
||||
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg))
|
||||
@@ -86,9 +136,15 @@ impl TranscriptCommitConfig {
|
||||
}
|
||||
|
||||
/// A builder for [`TranscriptCommitConfig`].
|
||||
///
|
||||
/// The default hash algorithm is [`HashAlgId::BLAKE3`] and the default kind
|
||||
/// is [`TranscriptCommitmentKind::Encoding`].
|
||||
#[derive(Debug)]
|
||||
pub struct TranscriptCommitConfigBuilder<'a> {
|
||||
transcript: &'a Transcript,
|
||||
encoding_hash_alg: HashAlgId,
|
||||
has_encoding: bool,
|
||||
has_hash: bool,
|
||||
default_kind: TranscriptCommitmentKind,
|
||||
commits: HashSet<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
|
||||
}
|
||||
@@ -98,13 +154,20 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
|
||||
pub fn new(transcript: &'a Transcript) -> Self {
|
||||
Self {
|
||||
transcript,
|
||||
default_kind: TranscriptCommitmentKind::Hash {
|
||||
alg: HashAlgId::BLAKE3,
|
||||
},
|
||||
encoding_hash_alg: HashAlgId::BLAKE3,
|
||||
has_encoding: false,
|
||||
has_hash: false,
|
||||
default_kind: TranscriptCommitmentKind::Encoding,
|
||||
commits: HashSet::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the hash algorithm to use for encoding commitments.
|
||||
pub fn encoding_hash_alg(&mut self, alg: HashAlgId) -> &mut Self {
|
||||
self.encoding_hash_alg = alg;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the default kind of commitment to use.
|
||||
pub fn default_kind(&mut self, default_kind: TranscriptCommitmentKind) -> &mut Self {
|
||||
self.default_kind = default_kind;
|
||||
@@ -138,6 +201,11 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
|
||||
));
|
||||
}
|
||||
|
||||
match kind {
|
||||
TranscriptCommitmentKind::Encoding => self.has_encoding = true,
|
||||
TranscriptCommitmentKind::Hash { .. } => self.has_hash = true,
|
||||
}
|
||||
|
||||
self.commits.insert(((direction, idx), kind));
|
||||
|
||||
Ok(self)
|
||||
@@ -184,6 +252,9 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
|
||||
/// Builds the configuration.
|
||||
pub fn build(self) -> Result<TranscriptCommitConfig, TranscriptCommitConfigBuilderError> {
|
||||
Ok(TranscriptCommitConfig {
|
||||
encoding_hash_alg: self.encoding_hash_alg,
|
||||
has_encoding: self.has_encoding,
|
||||
has_hash: self.has_hash,
|
||||
commits: Vec::from_iter(self.commits),
|
||||
})
|
||||
}
|
||||
@@ -230,10 +301,16 @@ impl fmt::Display for TranscriptCommitConfigBuilderError {
|
||||
/// Request to compute transcript commitments.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TranscriptCommitRequest {
|
||||
encoding: Option<(RangeSet<usize>, RangeSet<usize>)>,
|
||||
hash: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
|
||||
}
|
||||
|
||||
impl TranscriptCommitRequest {
|
||||
/// Returns `true` if an encoding commitment is requested.
|
||||
pub fn has_encoding(&self) -> bool {
|
||||
self.encoding.is_some()
|
||||
}
|
||||
|
||||
/// Returns `true` if a hash commitment is requested.
|
||||
pub fn has_hash(&self) -> bool {
|
||||
!self.hash.is_empty()
|
||||
@@ -243,6 +320,11 @@ impl TranscriptCommitRequest {
|
||||
pub fn iter_hash(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>, HashAlgId)> {
|
||||
self.hash.iter()
|
||||
}
|
||||
|
||||
/// Returns the ranges of the encoding commitments.
|
||||
pub fn encoding(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
|
||||
self.encoding.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
22
crates/core/src/transcript/encoding.rs
Normal file
22
crates/core/src/transcript/encoding.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
//! Transcript encoding commitments and proofs.
|
||||
|
||||
mod encoder;
|
||||
mod proof;
|
||||
mod provider;
|
||||
mod tree;
|
||||
|
||||
pub use encoder::{new_encoder, Encoder, EncoderSecret};
|
||||
pub use proof::{EncodingProof, EncodingProofError};
|
||||
pub use provider::{EncodingProvider, EncodingProviderError};
|
||||
pub use tree::{EncodingTree, EncodingTreeError};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::hash::TypedHash;
|
||||
|
||||
/// Transcript encoding commitment.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct EncodingCommitment {
|
||||
/// Merkle root of the encoding commitments.
|
||||
pub root: TypedHash,
|
||||
}
|
||||
137
crates/core/src/transcript/encoding/encoder.rs
Normal file
137
crates/core/src/transcript/encoding/encoder.rs
Normal file
@@ -0,0 +1,137 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::transcript::Direction;
|
||||
use itybity::ToBits;
|
||||
use rand::{RngCore, SeedableRng};
|
||||
use rand_chacha::ChaCha12Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// The size of the encoding for 1 bit, in bytes.
|
||||
const BIT_ENCODING_SIZE: usize = 16;
|
||||
/// The size of the encoding for 1 byte, in bytes.
|
||||
const BYTE_ENCODING_SIZE: usize = 128;
|
||||
|
||||
/// Secret used by an encoder to generate encodings.
|
||||
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct EncoderSecret {
|
||||
seed: [u8; 32],
|
||||
delta: [u8; BIT_ENCODING_SIZE],
|
||||
}
|
||||
|
||||
opaque_debug::implement!(EncoderSecret);
|
||||
|
||||
impl EncoderSecret {
|
||||
/// Creates a new secret.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `seed` - The seed for the PRG.
|
||||
/// * `delta` - Delta for deriving the one-encodings.
|
||||
pub fn new(seed: [u8; 32], delta: [u8; 16]) -> Self {
|
||||
Self { seed, delta }
|
||||
}
|
||||
|
||||
/// Returns the seed.
|
||||
pub fn seed(&self) -> &[u8; 32] {
|
||||
&self.seed
|
||||
}
|
||||
|
||||
/// Returns the delta.
|
||||
pub fn delta(&self) -> &[u8; 16] {
|
||||
&self.delta
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new encoder.
|
||||
pub fn new_encoder(secret: &EncoderSecret) -> impl Encoder {
|
||||
ChaChaEncoder::new(secret)
|
||||
}
|
||||
|
||||
pub(crate) struct ChaChaEncoder {
|
||||
seed: [u8; 32],
|
||||
delta: [u8; 16],
|
||||
}
|
||||
|
||||
impl ChaChaEncoder {
|
||||
pub(crate) fn new(secret: &EncoderSecret) -> Self {
|
||||
let seed = *secret.seed();
|
||||
let delta = *secret.delta();
|
||||
|
||||
Self { seed, delta }
|
||||
}
|
||||
|
||||
pub(crate) fn new_prg(&self, stream_id: u64) -> ChaCha12Rng {
|
||||
let mut prg = ChaCha12Rng::from_seed(self.seed);
|
||||
prg.set_stream(stream_id);
|
||||
prg.set_word_pos(0);
|
||||
prg
|
||||
}
|
||||
}
|
||||
|
||||
/// A transcript encoder.
|
||||
///
|
||||
/// This is an internal implementation detail that should not be exposed to the
|
||||
/// public API.
|
||||
pub trait Encoder {
|
||||
/// Writes the zero encoding for the given range of the transcript into the
|
||||
/// destination buffer.
|
||||
fn encode_range(&self, direction: Direction, range: Range<usize>, dest: &mut Vec<u8>);
|
||||
|
||||
/// Writes the encoding for the given data into the destination buffer.
|
||||
fn encode_data(
|
||||
&self,
|
||||
direction: Direction,
|
||||
range: Range<usize>,
|
||||
data: &[u8],
|
||||
dest: &mut Vec<u8>,
|
||||
);
|
||||
}
|
||||
|
||||
impl Encoder for ChaChaEncoder {
|
||||
fn encode_range(&self, direction: Direction, range: Range<usize>, dest: &mut Vec<u8>) {
|
||||
// ChaCha encoder works with 32-bit words. Each encoded bit is 128 bits long.
|
||||
const WORDS_PER_BYTE: u128 = 8 * 128 / 32;
|
||||
|
||||
let stream_id: u64 = match direction {
|
||||
Direction::Sent => 0,
|
||||
Direction::Received => 1,
|
||||
};
|
||||
|
||||
let mut prg = self.new_prg(stream_id);
|
||||
let len = range.len() * BYTE_ENCODING_SIZE;
|
||||
let pos = dest.len();
|
||||
|
||||
// Write 0s to the destination buffer.
|
||||
dest.resize(pos + len, 0);
|
||||
|
||||
// Fill the destination buffer with the PRG.
|
||||
prg.set_word_pos(range.start as u128 * WORDS_PER_BYTE);
|
||||
prg.fill_bytes(&mut dest[pos..pos + len]);
|
||||
}
|
||||
|
||||
fn encode_data(
|
||||
&self,
|
||||
direction: Direction,
|
||||
range: Range<usize>,
|
||||
data: &[u8],
|
||||
dest: &mut Vec<u8>,
|
||||
) {
|
||||
const ZERO: [u8; 16] = [0; BIT_ENCODING_SIZE];
|
||||
|
||||
let pos = dest.len();
|
||||
|
||||
// Write the zero encoding for the given range.
|
||||
self.encode_range(direction, range, dest);
|
||||
let dest = &mut dest[pos..];
|
||||
|
||||
for (pos, bit) in data.iter_lsb0().enumerate() {
|
||||
// Add the delta to the encoding whenever the encoded bit is 1,
|
||||
// otherwise add a zero.
|
||||
let summand = if bit { &self.delta } else { &ZERO };
|
||||
dest[pos * BIT_ENCODING_SIZE..(pos + 1) * BIT_ENCODING_SIZE]
|
||||
.iter_mut()
|
||||
.zip(summand)
|
||||
.for_each(|(a, b)| *a ^= *b);
|
||||
}
|
||||
}
|
||||
}
|
||||
361
crates/core/src/transcript/encoding/proof.rs
Normal file
361
crates/core/src/transcript/encoding/proof.rs
Normal file
@@ -0,0 +1,361 @@
|
||||
use std::{collections::HashMap, fmt};
|
||||
|
||||
use rangeset::set::RangeSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
hash::{Blinder, HashProvider, HashProviderError},
|
||||
merkle::{MerkleError, MerkleProof},
|
||||
transcript::{
|
||||
commit::MAX_TOTAL_COMMITTED_DATA,
|
||||
encoding::{new_encoder, Encoder, EncoderSecret, EncodingCommitment},
|
||||
Direction,
|
||||
},
|
||||
};
|
||||
|
||||
/// An opening of a leaf in the encoding tree.
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub(super) struct Opening {
|
||||
pub(super) direction: Direction,
|
||||
pub(super) idx: RangeSet<usize>,
|
||||
pub(super) blinder: Blinder,
|
||||
}
|
||||
|
||||
opaque_debug::implement!(Opening);
|
||||
|
||||
/// An encoding commitment proof.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(try_from = "validation::EncodingProofUnchecked")]
|
||||
pub struct EncodingProof {
|
||||
/// The proof of inclusion of the commitment(s) in the Merkle tree of
|
||||
/// commitments.
|
||||
pub(super) inclusion_proof: MerkleProof,
|
||||
pub(super) openings: HashMap<usize, Opening>,
|
||||
}
|
||||
|
||||
impl EncodingProof {
|
||||
/// Verifies the proof against the commitment.
|
||||
///
|
||||
/// Returns the authenticated indices of the sent and received data,
|
||||
/// respectively.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `provider` - Hash provider.
|
||||
/// * `commitment` - Encoding commitment to verify against.
|
||||
/// * `sent` - Sent data to authenticate.
|
||||
/// * `recv` - Received data to authenticate.
|
||||
pub fn verify_with_provider(
|
||||
&self,
|
||||
provider: &HashProvider,
|
||||
secret: &EncoderSecret,
|
||||
commitment: &EncodingCommitment,
|
||||
sent: &[u8],
|
||||
recv: &[u8],
|
||||
) -> Result<(RangeSet<usize>, RangeSet<usize>), EncodingProofError> {
|
||||
let hasher = provider.get(&commitment.root.alg)?;
|
||||
|
||||
let encoder = new_encoder(secret);
|
||||
let Self {
|
||||
inclusion_proof,
|
||||
openings,
|
||||
} = self;
|
||||
|
||||
let mut leaves = Vec::with_capacity(openings.len());
|
||||
let mut expected_leaf = Vec::default();
|
||||
let mut total_opened = 0u128;
|
||||
let mut auth_sent = RangeSet::default();
|
||||
let mut auth_recv = RangeSet::default();
|
||||
for (
|
||||
id,
|
||||
Opening {
|
||||
direction,
|
||||
idx,
|
||||
blinder,
|
||||
},
|
||||
) in openings
|
||||
{
|
||||
// Make sure the amount of data being proved is bounded.
|
||||
total_opened += idx.len() as u128;
|
||||
if total_opened > MAX_TOTAL_COMMITTED_DATA as u128 {
|
||||
return Err(EncodingProofError::new(
|
||||
ErrorKind::Proof,
|
||||
"exceeded maximum allowed data",
|
||||
))?;
|
||||
}
|
||||
|
||||
let (data, auth) = match direction {
|
||||
Direction::Sent => (sent, &mut auth_sent),
|
||||
Direction::Received => (recv, &mut auth_recv),
|
||||
};
|
||||
|
||||
// Make sure the ranges are within the bounds of the transcript.
|
||||
if idx.end().unwrap_or(0) > data.len() {
|
||||
return Err(EncodingProofError::new(
|
||||
ErrorKind::Proof,
|
||||
format!(
|
||||
"index out of bounds of the transcript ({}): {} > {}",
|
||||
direction,
|
||||
idx.end().unwrap_or(0),
|
||||
data.len()
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
expected_leaf.clear();
|
||||
for range in idx.iter() {
|
||||
encoder.encode_data(*direction, range.clone(), &data[range], &mut expected_leaf);
|
||||
}
|
||||
expected_leaf.extend_from_slice(blinder.as_bytes());
|
||||
|
||||
// Compute the expected hash of the commitment to make sure it is
|
||||
// present in the merkle tree.
|
||||
leaves.push((*id, hasher.hash(&expected_leaf)));
|
||||
|
||||
auth.union_mut(idx);
|
||||
}
|
||||
|
||||
// Verify that the expected hashes are present in the merkle tree.
|
||||
//
|
||||
// This proves the Prover committed to the purported data prior to the encoder
|
||||
// seed being revealed. Ergo, if the encodings are authentic then the purported
|
||||
// data is authentic.
|
||||
inclusion_proof.verify(hasher, &commitment.root, leaves)?;
|
||||
|
||||
Ok((auth_sent, auth_recv))
|
||||
}
|
||||
}
|
||||
|
||||
/// Error for [`EncodingProof`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub struct EncodingProofError {
|
||||
kind: ErrorKind,
|
||||
source: Option<Box<dyn std::error::Error + Send + Sync>>,
|
||||
}
|
||||
|
||||
impl EncodingProofError {
|
||||
fn new<E>(kind: ErrorKind, source: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
{
|
||||
Self {
|
||||
kind,
|
||||
source: Some(source.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ErrorKind {
|
||||
Provider,
|
||||
Proof,
|
||||
}
|
||||
|
||||
impl fmt::Display for EncodingProofError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str("encoding proof error: ")?;
|
||||
|
||||
match self.kind {
|
||||
ErrorKind::Provider => f.write_str("provider error")?,
|
||||
ErrorKind::Proof => f.write_str("proof error")?,
|
||||
}
|
||||
|
||||
if let Some(source) = &self.source {
|
||||
write!(f, " caused by: {source}")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<HashProviderError> for EncodingProofError {
|
||||
fn from(error: HashProviderError) -> Self {
|
||||
Self::new(ErrorKind::Provider, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<MerkleError> for EncodingProofError {
|
||||
fn from(error: MerkleError) -> Self {
|
||||
Self::new(ErrorKind::Proof, error)
|
||||
}
|
||||
}
|
||||
|
||||
/// Invalid encoding proof error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("invalid encoding proof: {0}")]
|
||||
pub struct InvalidEncodingProof(&'static str);
|
||||
|
||||
mod validation {
|
||||
use super::*;
|
||||
|
||||
/// The maximum allowed height of the Merkle tree of encoding commitments.
|
||||
///
|
||||
/// The statistical security parameter (SSP) of the encoding commitment
|
||||
/// protocol is calculated as "the number of uniformly random bits in a
|
||||
/// single bit's encoding minus `MAX_HEIGHT`".
|
||||
///
|
||||
/// For example, a bit encoding used in garbled circuits typically has 127
|
||||
/// uniformly random bits, hence when using it in the encoding
|
||||
/// commitment protocol, the SSP is 127 - 30 = 97 bits.
|
||||
///
|
||||
/// Leaving this validation here as a fail-safe in case we ever start
|
||||
/// using shorter encodings.
|
||||
const MAX_HEIGHT: usize = 30;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(super) struct EncodingProofUnchecked {
|
||||
inclusion_proof: MerkleProof,
|
||||
openings: HashMap<usize, Opening>,
|
||||
}
|
||||
|
||||
impl TryFrom<EncodingProofUnchecked> for EncodingProof {
|
||||
type Error = InvalidEncodingProof;
|
||||
|
||||
fn try_from(unchecked: EncodingProofUnchecked) -> Result<Self, Self::Error> {
|
||||
if unchecked.inclusion_proof.leaf_count() > 1 << MAX_HEIGHT {
|
||||
return Err(InvalidEncodingProof(
|
||||
"the height of the tree exceeds the maximum allowed",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
inclusion_proof: unchecked.inclusion_proof,
|
||||
openings: unchecked.openings,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use tlsn_data_fixtures::http::{request::POST_JSON, response::OK_JSON};
|
||||
|
||||
use crate::{
|
||||
fixtures::{encoder_secret, encoder_secret_tampered_seed, encoding_provider},
|
||||
hash::Blake3,
|
||||
transcript::{encoding::EncodingTree, Transcript},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
struct EncodingFixture {
|
||||
transcript: Transcript,
|
||||
proof: EncodingProof,
|
||||
commitment: EncodingCommitment,
|
||||
}
|
||||
|
||||
fn new_encoding_fixture() -> EncodingFixture {
|
||||
let transcript = Transcript::new(POST_JSON, OK_JSON);
|
||||
|
||||
let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
|
||||
let idx_1 = (Direction::Received, RangeSet::from(0..OK_JSON.len()));
|
||||
|
||||
let provider = encoding_provider(transcript.sent(), transcript.received());
|
||||
let tree = EncodingTree::new(&Blake3::default(), [&idx_0, &idx_1], &provider).unwrap();
|
||||
|
||||
let proof = tree.proof([&idx_0, &idx_1].into_iter()).unwrap();
|
||||
|
||||
let commitment = EncodingCommitment { root: tree.root() };
|
||||
|
||||
EncodingFixture {
|
||||
transcript,
|
||||
proof,
|
||||
commitment,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_encoding_proof_tampered_seed() {
|
||||
let EncodingFixture {
|
||||
transcript,
|
||||
proof,
|
||||
commitment,
|
||||
} = new_encoding_fixture();
|
||||
|
||||
let err = proof
|
||||
.verify_with_provider(
|
||||
&HashProvider::default(),
|
||||
&encoder_secret_tampered_seed(),
|
||||
&commitment,
|
||||
transcript.sent(),
|
||||
transcript.received(),
|
||||
)
|
||||
.unwrap_err();
|
||||
|
||||
assert!(matches!(err.kind, ErrorKind::Proof));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_encoding_proof_out_of_range() {
|
||||
let EncodingFixture {
|
||||
transcript,
|
||||
proof,
|
||||
commitment,
|
||||
} = new_encoding_fixture();
|
||||
|
||||
let sent = &transcript.sent()[transcript.sent().len() - 1..];
|
||||
let recv = &transcript.received()[transcript.received().len() - 2..];
|
||||
|
||||
let err = proof
|
||||
.verify_with_provider(
|
||||
&HashProvider::default(),
|
||||
&encoder_secret(),
|
||||
&commitment,
|
||||
sent,
|
||||
recv,
|
||||
)
|
||||
.unwrap_err();
|
||||
|
||||
assert!(matches!(err.kind, ErrorKind::Proof));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_encoding_proof_tampered_idx() {
|
||||
let EncodingFixture {
|
||||
transcript,
|
||||
mut proof,
|
||||
commitment,
|
||||
} = new_encoding_fixture();
|
||||
|
||||
let Opening { idx, .. } = proof.openings.values_mut().next().unwrap();
|
||||
|
||||
*idx = RangeSet::from([0..3, 13..15]);
|
||||
|
||||
let err = proof
|
||||
.verify_with_provider(
|
||||
&HashProvider::default(),
|
||||
&encoder_secret(),
|
||||
&commitment,
|
||||
transcript.sent(),
|
||||
transcript.received(),
|
||||
)
|
||||
.unwrap_err();
|
||||
|
||||
assert!(matches!(err.kind, ErrorKind::Proof));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_encoding_proof_tampered_encoding_blinder() {
|
||||
let EncodingFixture {
|
||||
transcript,
|
||||
mut proof,
|
||||
commitment,
|
||||
} = new_encoding_fixture();
|
||||
|
||||
let Opening { blinder, .. } = proof.openings.values_mut().next().unwrap();
|
||||
|
||||
*blinder = rand::random();
|
||||
|
||||
let err = proof
|
||||
.verify_with_provider(
|
||||
&HashProvider::default(),
|
||||
&encoder_secret(),
|
||||
&commitment,
|
||||
transcript.sent(),
|
||||
transcript.received(),
|
||||
)
|
||||
.unwrap_err();
|
||||
|
||||
assert!(matches!(err.kind, ErrorKind::Proof));
|
||||
}
|
||||
}
|
||||
19
crates/core/src/transcript/encoding/provider.rs
Normal file
19
crates/core/src/transcript/encoding/provider.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::transcript::Direction;
|
||||
|
||||
/// A provider of plaintext encodings.
|
||||
pub trait EncodingProvider {
|
||||
/// Writes the encoding of the given range into the destination buffer.
|
||||
fn provide_encoding(
|
||||
&self,
|
||||
direction: Direction,
|
||||
range: Range<usize>,
|
||||
dest: &mut Vec<u8>,
|
||||
) -> Result<(), EncodingProviderError>;
|
||||
}
|
||||
|
||||
/// Error for [`EncodingProvider`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("failed to provide encoding")]
|
||||
pub struct EncodingProviderError;
|
||||
327
crates/core/src/transcript/encoding/tree.rs
Normal file
327
crates/core/src/transcript/encoding/tree.rs
Normal file
@@ -0,0 +1,327 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use bimap::BiMap;
|
||||
use rangeset::set::RangeSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
hash::{Blinder, HashAlgId, HashAlgorithm, TypedHash},
|
||||
merkle::MerkleTree,
|
||||
transcript::{
|
||||
encoding::{
|
||||
proof::{EncodingProof, Opening},
|
||||
EncodingProvider,
|
||||
},
|
||||
Direction,
|
||||
},
|
||||
};
|
||||
|
||||
/// Encoding tree builder error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum EncodingTreeError {
|
||||
/// Index is out of bounds of the transcript.
|
||||
#[error("index is out of bounds of the transcript")]
|
||||
OutOfBounds {
|
||||
/// The index.
|
||||
index: RangeSet<usize>,
|
||||
/// The transcript length.
|
||||
transcript_length: usize,
|
||||
},
|
||||
/// Encoding provider is missing an encoding for an index.
|
||||
#[error("encoding provider is missing an encoding for an index")]
|
||||
MissingEncoding {
|
||||
/// The index which is missing.
|
||||
index: RangeSet<usize>,
|
||||
},
|
||||
/// Index is missing from the tree.
|
||||
#[error("index is missing from the tree")]
|
||||
MissingLeaf {
|
||||
/// The index which is missing.
|
||||
index: RangeSet<usize>,
|
||||
},
|
||||
}
|
||||
|
||||
/// A merkle tree of transcript encodings.
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct EncodingTree {
|
||||
/// Merkle tree of the commitments.
|
||||
tree: MerkleTree,
|
||||
/// Nonces used to blind the hashes.
|
||||
blinders: Vec<Blinder>,
|
||||
/// Mapping between the index of a leaf and the transcript index it
|
||||
/// corresponds to.
|
||||
idxs: BiMap<usize, (Direction, RangeSet<usize>)>,
|
||||
/// Union of all transcript indices in the sent direction.
|
||||
sent_idx: RangeSet<usize>,
|
||||
/// Union of all transcript indices in the received direction.
|
||||
received_idx: RangeSet<usize>,
|
||||
}
|
||||
|
||||
opaque_debug::implement!(EncodingTree);
|
||||
|
||||
impl EncodingTree {
|
||||
/// Creates a new encoding tree.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `hasher` - The hash algorithm to use.
|
||||
/// * `idxs` - The subsequence indices to commit to.
|
||||
/// * `provider` - The encoding provider.
|
||||
pub fn new<'idx>(
|
||||
hasher: &dyn HashAlgorithm,
|
||||
idxs: impl IntoIterator<Item = &'idx (Direction, RangeSet<usize>)>,
|
||||
provider: &dyn EncodingProvider,
|
||||
) -> Result<Self, EncodingTreeError> {
|
||||
let mut this = Self {
|
||||
tree: MerkleTree::new(hasher.id()),
|
||||
blinders: Vec::new(),
|
||||
idxs: BiMap::new(),
|
||||
sent_idx: RangeSet::default(),
|
||||
received_idx: RangeSet::default(),
|
||||
};
|
||||
|
||||
let mut leaves = Vec::new();
|
||||
let mut encoding = Vec::new();
|
||||
for dir_idx in idxs {
|
||||
let direction = dir_idx.0;
|
||||
let idx = &dir_idx.1;
|
||||
|
||||
// Ignore empty indices.
|
||||
if idx.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if this.idxs.contains_right(dir_idx) {
|
||||
// The subsequence is already in the tree.
|
||||
continue;
|
||||
}
|
||||
|
||||
let blinder: Blinder = rand::random();
|
||||
|
||||
encoding.clear();
|
||||
for range in idx.iter() {
|
||||
provider
|
||||
.provide_encoding(direction, range, &mut encoding)
|
||||
.map_err(|_| EncodingTreeError::MissingEncoding { index: idx.clone() })?;
|
||||
}
|
||||
encoding.extend_from_slice(blinder.as_bytes());
|
||||
|
||||
let leaf = hasher.hash(&encoding);
|
||||
|
||||
leaves.push(leaf);
|
||||
this.blinders.push(blinder);
|
||||
this.idxs.insert(this.idxs.len(), dir_idx.clone());
|
||||
match direction {
|
||||
Direction::Sent => this.sent_idx.union_mut(idx),
|
||||
Direction::Received => this.received_idx.union_mut(idx),
|
||||
}
|
||||
}
|
||||
|
||||
this.tree.insert(hasher, leaves);
|
||||
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
/// Returns the root of the tree.
|
||||
pub fn root(&self) -> TypedHash {
|
||||
self.tree.root()
|
||||
}
|
||||
|
||||
/// Returns the hash algorithm of the tree.
|
||||
pub fn algorithm(&self) -> HashAlgId {
|
||||
self.tree.algorithm()
|
||||
}
|
||||
|
||||
/// Generates a proof for the given indices.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `idxs` - The transcript indices to prove.
|
||||
pub fn proof<'idx>(
|
||||
&self,
|
||||
idxs: impl Iterator<Item = &'idx (Direction, RangeSet<usize>)>,
|
||||
) -> Result<EncodingProof, EncodingTreeError> {
|
||||
let mut openings = HashMap::new();
|
||||
for dir_idx in idxs {
|
||||
let direction = dir_idx.0;
|
||||
let idx = &dir_idx.1;
|
||||
|
||||
let leaf_idx = *self
|
||||
.idxs
|
||||
.get_by_right(dir_idx)
|
||||
.ok_or_else(|| EncodingTreeError::MissingLeaf { index: idx.clone() })?;
|
||||
let blinder = self.blinders[leaf_idx].clone();
|
||||
|
||||
openings.insert(
|
||||
leaf_idx,
|
||||
Opening {
|
||||
direction,
|
||||
idx: idx.clone(),
|
||||
blinder,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let mut indices = openings.keys().copied().collect::<Vec<_>>();
|
||||
indices.sort();
|
||||
|
||||
Ok(EncodingProof {
|
||||
inclusion_proof: self.tree.proof(&indices),
|
||||
openings,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns whether the tree contains the given transcript index.
|
||||
pub fn contains(&self, idx: &(Direction, RangeSet<usize>)) -> bool {
|
||||
self.idxs.contains_right(idx)
|
||||
}
|
||||
|
||||
pub(crate) fn idx(&self, direction: Direction) -> &RangeSet<usize> {
|
||||
match direction {
|
||||
Direction::Sent => &self.sent_idx,
|
||||
Direction::Received => &self.received_idx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the committed transcript indices.
|
||||
pub(crate) fn transcript_indices(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
|
||||
self.idxs.right_values()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
fixtures::{encoder_secret, encoding_provider},
|
||||
hash::{Blake3, HashProvider},
|
||||
transcript::{encoding::EncodingCommitment, Transcript},
|
||||
};
|
||||
use tlsn_data_fixtures::http::{request::POST_JSON, response::OK_JSON};
|
||||
|
||||
fn new_tree<'seq>(
|
||||
transcript: &Transcript,
|
||||
idxs: impl Iterator<Item = &'seq (Direction, RangeSet<usize>)>,
|
||||
) -> Result<EncodingTree, EncodingTreeError> {
|
||||
let provider = encoding_provider(transcript.sent(), transcript.received());
|
||||
|
||||
EncodingTree::new(&Blake3::default(), idxs, &provider)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encoding_tree() {
|
||||
let transcript = Transcript::new(POST_JSON, OK_JSON);
|
||||
|
||||
let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
|
||||
let idx_1 = (Direction::Received, RangeSet::from(0..OK_JSON.len()));
|
||||
|
||||
let tree = new_tree(&transcript, [&idx_0, &idx_1].into_iter()).unwrap();
|
||||
|
||||
assert!(tree.contains(&idx_0));
|
||||
assert!(tree.contains(&idx_1));
|
||||
|
||||
let proof = tree.proof([&idx_0, &idx_1].into_iter()).unwrap();
|
||||
|
||||
let commitment = EncodingCommitment { root: tree.root() };
|
||||
|
||||
let (auth_sent, auth_recv) = proof
|
||||
.verify_with_provider(
|
||||
&HashProvider::default(),
|
||||
&encoder_secret(),
|
||||
&commitment,
|
||||
transcript.sent(),
|
||||
transcript.received(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(auth_sent, idx_0.1);
|
||||
assert_eq!(auth_recv, idx_1.1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encoding_tree_multiple_ranges() {
|
||||
let transcript = Transcript::new(POST_JSON, OK_JSON);
|
||||
|
||||
let idx_0 = (Direction::Sent, RangeSet::from(0..1));
|
||||
let idx_1 = (Direction::Sent, RangeSet::from(1..POST_JSON.len()));
|
||||
let idx_2 = (Direction::Received, RangeSet::from(0..1));
|
||||
let idx_3 = (Direction::Received, RangeSet::from(1..OK_JSON.len()));
|
||||
|
||||
let tree = new_tree(&transcript, [&idx_0, &idx_1, &idx_2, &idx_3].into_iter()).unwrap();
|
||||
|
||||
assert!(tree.contains(&idx_0));
|
||||
assert!(tree.contains(&idx_1));
|
||||
assert!(tree.contains(&idx_2));
|
||||
assert!(tree.contains(&idx_3));
|
||||
|
||||
let proof = tree
|
||||
.proof([&idx_0, &idx_1, &idx_2, &idx_3].into_iter())
|
||||
.unwrap();
|
||||
|
||||
let commitment = EncodingCommitment { root: tree.root() };
|
||||
|
||||
let (auth_sent, auth_recv) = proof
|
||||
.verify_with_provider(
|
||||
&HashProvider::default(),
|
||||
&encoder_secret(),
|
||||
&commitment,
|
||||
transcript.sent(),
|
||||
transcript.received(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut expected_auth_sent = RangeSet::default();
|
||||
expected_auth_sent.union_mut(&idx_0.1);
|
||||
expected_auth_sent.union_mut(&idx_1.1);
|
||||
|
||||
let mut expected_auth_recv = RangeSet::default();
|
||||
expected_auth_recv.union_mut(&idx_2.1);
|
||||
expected_auth_recv.union_mut(&idx_3.1);
|
||||
|
||||
assert_eq!(auth_sent, expected_auth_sent);
|
||||
assert_eq!(auth_recv, expected_auth_recv);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encoding_tree_proof_missing_leaf() {
|
||||
let transcript = Transcript::new(POST_JSON, OK_JSON);
|
||||
|
||||
let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
|
||||
let idx_1 = (Direction::Received, RangeSet::from(0..4));
|
||||
let idx_2 = (Direction::Received, RangeSet::from(4..OK_JSON.len()));
|
||||
|
||||
let tree = new_tree(&transcript, [&idx_0, &idx_1].into_iter()).unwrap();
|
||||
|
||||
let result = tree
|
||||
.proof([&idx_0, &idx_1, &idx_2].into_iter())
|
||||
.unwrap_err();
|
||||
assert!(matches!(result, EncodingTreeError::MissingLeaf { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encoding_tree_out_of_bounds() {
|
||||
let transcript = Transcript::new(POST_JSON, OK_JSON);
|
||||
|
||||
let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len() + 1));
|
||||
let idx_1 = (Direction::Received, RangeSet::from(0..OK_JSON.len() + 1));
|
||||
|
||||
let result = new_tree(&transcript, [&idx_0].into_iter()).unwrap_err();
|
||||
assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));
|
||||
|
||||
let result = new_tree(&transcript, [&idx_1].into_iter()).unwrap_err();
|
||||
assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encoding_tree_missing_encoding() {
|
||||
let provider = encoding_provider(&[], &[]);
|
||||
|
||||
let result = EncodingTree::new(
|
||||
&Blake3::default(),
|
||||
[(Direction::Sent, RangeSet::from(0..8))].iter(),
|
||||
&provider,
|
||||
)
|
||||
.unwrap_err();
|
||||
assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,7 @@ use crate::{
|
||||
hash::{HashAlgId, HashProvider},
|
||||
transcript::{
|
||||
commit::{TranscriptCommitment, TranscriptCommitmentKind},
|
||||
encoding::{EncoderSecret, EncodingProof, EncodingProofError, EncodingTree},
|
||||
hash::{hash_plaintext, PlaintextHash, PlaintextHashSecret},
|
||||
Direction, PartialTranscript, RangeSet, Transcript, TranscriptSecret,
|
||||
},
|
||||
@@ -31,12 +32,14 @@ const DEFAULT_COMMITMENT_KINDS: &[TranscriptCommitmentKind] = &[
|
||||
TranscriptCommitmentKind::Hash {
|
||||
alg: HashAlgId::KECCAK256,
|
||||
},
|
||||
TranscriptCommitmentKind::Encoding,
|
||||
];
|
||||
|
||||
/// Proof of the contents of a transcript.
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct TranscriptProof {
|
||||
transcript: PartialTranscript,
|
||||
encoding_proof: Option<EncodingProof>,
|
||||
hash_secrets: Vec<PlaintextHashSecret>,
|
||||
}
|
||||
|
||||
@@ -50,18 +53,27 @@ impl TranscriptProof {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `provider` - The hash provider to use for verification.
|
||||
/// * `length` - The transcript length.
|
||||
/// * `commitments` - The commitments to verify against.
|
||||
/// * `attestation_body` - The attestation body to verify against.
|
||||
pub fn verify_with_provider<'a>(
|
||||
self,
|
||||
provider: &HashProvider,
|
||||
length: &TranscriptLength,
|
||||
encoder_secret: Option<&EncoderSecret>,
|
||||
commitments: impl IntoIterator<Item = &'a TranscriptCommitment>,
|
||||
) -> Result<PartialTranscript, TranscriptProofError> {
|
||||
let mut encoding_commitment = None;
|
||||
let mut hash_commitments = HashSet::new();
|
||||
// Index commitments.
|
||||
for commitment in commitments {
|
||||
match commitment {
|
||||
TranscriptCommitment::Encoding(commitment) => {
|
||||
if encoding_commitment.replace(commitment).is_some() {
|
||||
return Err(TranscriptProofError::new(
|
||||
ErrorKind::Encoding,
|
||||
"multiple encoding commitments are present.",
|
||||
));
|
||||
}
|
||||
}
|
||||
TranscriptCommitment::Hash(plaintext_hash) => {
|
||||
hash_commitments.insert(plaintext_hash);
|
||||
}
|
||||
@@ -80,6 +92,34 @@ impl TranscriptProof {
|
||||
let mut total_auth_sent = RangeSet::default();
|
||||
let mut total_auth_recv = RangeSet::default();
|
||||
|
||||
// Verify encoding proof.
|
||||
if let Some(proof) = self.encoding_proof {
|
||||
let secret = encoder_secret.ok_or_else(|| {
|
||||
TranscriptProofError::new(
|
||||
ErrorKind::Encoding,
|
||||
"contains an encoding proof but missing encoder secret",
|
||||
)
|
||||
})?;
|
||||
|
||||
let commitment = encoding_commitment.ok_or_else(|| {
|
||||
TranscriptProofError::new(
|
||||
ErrorKind::Encoding,
|
||||
"contains an encoding proof but missing encoding commitment",
|
||||
)
|
||||
})?;
|
||||
|
||||
let (auth_sent, auth_recv) = proof.verify_with_provider(
|
||||
provider,
|
||||
secret,
|
||||
commitment,
|
||||
self.transcript.sent_unsafe(),
|
||||
self.transcript.received_unsafe(),
|
||||
)?;
|
||||
|
||||
total_auth_sent.union_mut(&auth_sent);
|
||||
total_auth_recv.union_mut(&auth_recv);
|
||||
}
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
for PlaintextHashSecret {
|
||||
direction,
|
||||
@@ -163,6 +203,7 @@ impl TranscriptProofError {
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ErrorKind {
|
||||
Encoding,
|
||||
Hash,
|
||||
Proof,
|
||||
}
|
||||
@@ -172,6 +213,7 @@ impl fmt::Display for TranscriptProofError {
|
||||
f.write_str("transcript proof error: ")?;
|
||||
|
||||
match self.kind {
|
||||
ErrorKind::Encoding => f.write_str("encoding error")?,
|
||||
ErrorKind::Hash => f.write_str("hash error")?,
|
||||
ErrorKind::Proof => f.write_str("proof error")?,
|
||||
}
|
||||
@@ -184,6 +226,12 @@ impl fmt::Display for TranscriptProofError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EncodingProofError> for TranscriptProofError {
|
||||
fn from(e: EncodingProofError) -> Self {
|
||||
TranscriptProofError::new(ErrorKind::Encoding, e)
|
||||
}
|
||||
}
|
||||
|
||||
/// Union of ranges to reveal.
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
struct QueryIdx {
|
||||
@@ -228,6 +276,7 @@ pub struct TranscriptProofBuilder<'a> {
|
||||
/// Commitment kinds in order of preference for building transcript proofs.
|
||||
commitment_kinds: Vec<TranscriptCommitmentKind>,
|
||||
transcript: &'a Transcript,
|
||||
encoding_tree: Option<&'a EncodingTree>,
|
||||
hash_secrets: Vec<&'a PlaintextHashSecret>,
|
||||
committed_sent: RangeSet<usize>,
|
||||
committed_recv: RangeSet<usize>,
|
||||
@@ -243,9 +292,15 @@ impl<'a> TranscriptProofBuilder<'a> {
|
||||
let mut committed_sent = RangeSet::default();
|
||||
let mut committed_recv = RangeSet::default();
|
||||
|
||||
let mut encoding_tree = None;
|
||||
let mut hash_secrets = Vec::new();
|
||||
for secret in secrets {
|
||||
match secret {
|
||||
TranscriptSecret::Encoding(tree) => {
|
||||
committed_sent.union_mut(tree.idx(Direction::Sent));
|
||||
committed_recv.union_mut(tree.idx(Direction::Received));
|
||||
encoding_tree = Some(tree);
|
||||
}
|
||||
TranscriptSecret::Hash(hash) => {
|
||||
match hash.direction {
|
||||
Direction::Sent => committed_sent.union_mut(&hash.idx),
|
||||
@@ -259,6 +314,7 @@ impl<'a> TranscriptProofBuilder<'a> {
|
||||
Self {
|
||||
commitment_kinds: DEFAULT_COMMITMENT_KINDS.to_vec(),
|
||||
transcript,
|
||||
encoding_tree,
|
||||
hash_secrets,
|
||||
committed_sent,
|
||||
committed_recv,
|
||||
@@ -356,6 +412,7 @@ impl<'a> TranscriptProofBuilder<'a> {
|
||||
transcript: self
|
||||
.transcript
|
||||
.to_partial(self.query_idx.sent.clone(), self.query_idx.recv.clone()),
|
||||
encoding_proof: None,
|
||||
hash_secrets: Vec::new(),
|
||||
};
|
||||
let mut uncovered_query_idx = self.query_idx.clone();
|
||||
@@ -367,6 +424,46 @@ impl<'a> TranscriptProofBuilder<'a> {
|
||||
// self.commitment_kinds.
|
||||
if let Some(kind) = commitment_kinds_iter.next() {
|
||||
match kind {
|
||||
TranscriptCommitmentKind::Encoding => {
|
||||
let Some(encoding_tree) = self.encoding_tree else {
|
||||
// Proceeds to the next preferred commitment kind if encoding tree is
|
||||
// not available.
|
||||
continue;
|
||||
};
|
||||
|
||||
let (sent_dir_idxs, sent_uncovered) = uncovered_query_idx.sent.cover_by(
|
||||
encoding_tree
|
||||
.transcript_indices()
|
||||
.filter(|(dir, _)| *dir == Direction::Sent),
|
||||
|(_, idx)| idx,
|
||||
);
|
||||
// Uncovered ranges will be checked with ranges of the next
|
||||
// preferred commitment kind.
|
||||
uncovered_query_idx.sent = sent_uncovered;
|
||||
|
||||
let (recv_dir_idxs, recv_uncovered) = uncovered_query_idx.recv.cover_by(
|
||||
encoding_tree
|
||||
.transcript_indices()
|
||||
.filter(|(dir, _)| *dir == Direction::Received),
|
||||
|(_, idx)| idx,
|
||||
);
|
||||
uncovered_query_idx.recv = recv_uncovered;
|
||||
|
||||
let dir_idxs = sent_dir_idxs
|
||||
.into_iter()
|
||||
.chain(recv_dir_idxs)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Skip proof generation if there are no committed ranges that can cover the
|
||||
// query ranges.
|
||||
if !dir_idxs.is_empty() {
|
||||
transcript_proof.encoding_proof = Some(
|
||||
encoding_tree
|
||||
.proof(dir_idxs.into_iter())
|
||||
.expect("subsequences were checked to be in tree"),
|
||||
);
|
||||
}
|
||||
}
|
||||
TranscriptCommitmentKind::Hash { alg } => {
|
||||
let (sent_hashes, sent_uncovered) = uncovered_query_idx.sent.cover_by(
|
||||
self.hash_secrets.iter().filter(|hash| {
|
||||
@@ -493,10 +590,46 @@ mod tests {
|
||||
use rstest::rstest;
|
||||
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
|
||||
|
||||
use crate::hash::{Blinder, HashAlgId};
|
||||
use crate::{
|
||||
fixtures::{encoder_secret, encoding_provider},
|
||||
hash::{Blake3, Blinder, HashAlgId},
|
||||
transcript::TranscriptCommitConfigBuilder,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[rstest]
|
||||
fn test_verify_missing_encoding_commitment_root() {
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
let idxs = vec![(Direction::Received, RangeSet::from(0..transcript.len().1))];
|
||||
let encoding_tree = EncodingTree::new(
|
||||
&Blake3::default(),
|
||||
&idxs,
|
||||
&encoding_provider(transcript.sent(), transcript.received()),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let secrets = vec![TranscriptSecret::Encoding(encoding_tree)];
|
||||
let mut builder = TranscriptProofBuilder::new(&transcript, &secrets);
|
||||
|
||||
builder.reveal_recv(&(0..transcript.len().1)).unwrap();
|
||||
|
||||
let transcript_proof = builder.build().unwrap();
|
||||
|
||||
let provider = HashProvider::default();
|
||||
let err = transcript_proof
|
||||
.verify_with_provider(
|
||||
&provider,
|
||||
&transcript.length(),
|
||||
Some(&encoder_secret()),
|
||||
&[],
|
||||
)
|
||||
.err()
|
||||
.unwrap();
|
||||
|
||||
assert!(matches!(err.kind, ErrorKind::Encoding));
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
fn test_reveal_range_out_of_bounds() {
|
||||
let transcript = Transcript::new(
|
||||
@@ -516,7 +649,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
fn test_reveal_missing_commitment() {
|
||||
fn test_reveal_missing_encoding_tree() {
|
||||
let transcript = Transcript::new(
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
||||
@@ -565,6 +698,7 @@ mod tests {
|
||||
.verify_with_provider(
|
||||
&provider,
|
||||
&transcript.length(),
|
||||
None,
|
||||
&[TranscriptCommitment::Hash(commitment)],
|
||||
)
|
||||
.unwrap();
|
||||
@@ -614,6 +748,7 @@ mod tests {
|
||||
.verify_with_provider(
|
||||
&provider,
|
||||
&transcript.length(),
|
||||
None,
|
||||
&[TranscriptCommitment::Hash(commitment)],
|
||||
)
|
||||
.unwrap_err();
|
||||
@@ -629,19 +764,24 @@ mod tests {
|
||||
TranscriptCommitmentKind::Hash {
|
||||
alg: HashAlgId::SHA256,
|
||||
},
|
||||
TranscriptCommitmentKind::Encoding,
|
||||
TranscriptCommitmentKind::Hash {
|
||||
alg: HashAlgId::SHA256,
|
||||
},
|
||||
TranscriptCommitmentKind::Hash {
|
||||
alg: HashAlgId::SHA256,
|
||||
},
|
||||
TranscriptCommitmentKind::Encoding,
|
||||
]);
|
||||
|
||||
assert_eq!(
|
||||
builder.commitment_kinds,
|
||||
vec![TranscriptCommitmentKind::Hash {
|
||||
alg: HashAlgId::SHA256
|
||||
},]
|
||||
vec![
|
||||
TranscriptCommitmentKind::Hash {
|
||||
alg: HashAlgId::SHA256
|
||||
},
|
||||
TranscriptCommitmentKind::Encoding
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
@@ -651,7 +791,7 @@ mod tests {
|
||||
RangeSet::from([0..10, 12..30]),
|
||||
true,
|
||||
)]
|
||||
#[case::reveal_all_rangesets_with_single_superset_range(
|
||||
#[case::reveal_all_rangesets_with_superset_ranges(
|
||||
vec![RangeSet::from([0..1]), RangeSet::from([1..2, 8..9]), RangeSet::from([2..4, 6..8]), RangeSet::from([2..3, 6..7]), RangeSet::from([9..12])],
|
||||
RangeSet::from([0..4, 6..9]),
|
||||
true,
|
||||
@@ -682,30 +822,29 @@ mod tests {
|
||||
false,
|
||||
)]
|
||||
#[allow(clippy::single_range_in_vec_init)]
|
||||
fn test_reveal_multiple_rangesets_with_one_rangeset(
|
||||
fn test_reveal_mutliple_rangesets_with_one_rangeset(
|
||||
#[case] commit_recv_rangesets: Vec<RangeSet<usize>>,
|
||||
#[case] reveal_recv_rangeset: RangeSet<usize>,
|
||||
#[case] success: bool,
|
||||
) {
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
|
||||
// Create hash commitments for each rangeset
|
||||
let mut secrets = Vec::new();
|
||||
// Encoding commitment kind
|
||||
let mut transcript_commitment_builder = TranscriptCommitConfigBuilder::new(&transcript);
|
||||
for rangeset in commit_recv_rangesets.iter() {
|
||||
let blinder: crate::hash::Blinder = rng.random();
|
||||
|
||||
let secret = PlaintextHashSecret {
|
||||
direction: Direction::Received,
|
||||
idx: rangeset.clone(),
|
||||
alg: HashAlgId::BLAKE3,
|
||||
blinder,
|
||||
};
|
||||
secrets.push(TranscriptSecret::Hash(secret));
|
||||
transcript_commitment_builder.commit_recv(rangeset).unwrap();
|
||||
}
|
||||
|
||||
let transcripts_commitment_config = transcript_commitment_builder.build().unwrap();
|
||||
|
||||
let encoding_tree = EncodingTree::new(
|
||||
&Blake3::default(),
|
||||
transcripts_commitment_config.iter_encoding(),
|
||||
&encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let secrets = vec![TranscriptSecret::Encoding(encoding_tree)];
|
||||
let mut builder = TranscriptProofBuilder::new(&transcript, &secrets);
|
||||
|
||||
if success {
|
||||
@@ -758,34 +897,27 @@ mod tests {
|
||||
#[case] uncovered_sent_rangeset: RangeSet<usize>,
|
||||
#[case] uncovered_recv_rangeset: RangeSet<usize>,
|
||||
) {
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
|
||||
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
|
||||
|
||||
// Create hash commitments for each rangeset
|
||||
let mut secrets = Vec::new();
|
||||
// Encoding commitment kind
|
||||
let mut transcript_commitment_builder = TranscriptCommitConfigBuilder::new(&transcript);
|
||||
for rangeset in commit_sent_rangesets.iter() {
|
||||
let blinder: crate::hash::Blinder = rng.random();
|
||||
let secret = PlaintextHashSecret {
|
||||
direction: Direction::Sent,
|
||||
idx: rangeset.clone(),
|
||||
alg: HashAlgId::BLAKE3,
|
||||
blinder,
|
||||
};
|
||||
secrets.push(TranscriptSecret::Hash(secret));
|
||||
transcript_commitment_builder.commit_sent(rangeset).unwrap();
|
||||
}
|
||||
for rangeset in commit_recv_rangesets.iter() {
|
||||
let blinder: crate::hash::Blinder = rng.random();
|
||||
let secret = PlaintextHashSecret {
|
||||
direction: Direction::Received,
|
||||
idx: rangeset.clone(),
|
||||
alg: HashAlgId::BLAKE3,
|
||||
blinder,
|
||||
};
|
||||
secrets.push(TranscriptSecret::Hash(secret));
|
||||
transcript_commitment_builder.commit_recv(rangeset).unwrap();
|
||||
}
|
||||
|
||||
let transcripts_commitment_config = transcript_commitment_builder.build().unwrap();
|
||||
|
||||
let encoding_tree = EncodingTree::new(
|
||||
&Blake3::default(),
|
||||
transcripts_commitment_config.iter_encoding(),
|
||||
&encoding_provider(GET_WITH_HEADER, OK_JSON),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let secrets = vec![TranscriptSecret::Encoding(encoding_tree)];
|
||||
let mut builder = TranscriptProofBuilder::new(&transcript, &secrets);
|
||||
builder.reveal_sent(&reveal_sent_rangeset).unwrap();
|
||||
builder.reveal_recv(&reveal_recv_rangeset).unwrap();
|
||||
|
||||
@@ -87,13 +87,15 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
|
||||
async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
socket: S,
|
||||
verifier_socket: S,
|
||||
req_tx: Sender<AttestationRequest>,
|
||||
resp_rx: Receiver<Attestation>,
|
||||
uri: &str,
|
||||
extra_headers: Vec<(&str, &str)>,
|
||||
example_type: &ExampleType,
|
||||
) -> Result<()> {
|
||||
let mut verifier_socket = verifier_socket.compat();
|
||||
|
||||
let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into());
|
||||
let server_port: u16 = env::var("SERVER_PORT")
|
||||
.map(|port| port.parse().expect("port should be valid integer"))
|
||||
@@ -115,37 +117,36 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
.build()?,
|
||||
)
|
||||
.build()?,
|
||||
socket.compat(),
|
||||
&mut verifier_socket,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Open a TCP connection to the server.
|
||||
let client_socket = tokio::net::TcpStream::connect((server_host, server_port)).await?;
|
||||
let client_socket = tokio::net::TcpStream::connect((server_host, server_port))
|
||||
.await?
|
||||
.compat();
|
||||
|
||||
// Bind the prover to the server connection.
|
||||
let (tls_connection, prover_fut) = prover
|
||||
.connect(
|
||||
TlsClientConfig::builder()
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
// server-fixture.
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
// (Optional) Set up TLS client authentication if required by the server.
|
||||
.client_auth((
|
||||
vec![CertificateDer(CLIENT_CERT_DER.to_vec())],
|
||||
PrivateKeyDer(CLIENT_KEY_DER.to_vec()),
|
||||
))
|
||||
.build()?,
|
||||
client_socket.compat(),
|
||||
)
|
||||
.await?;
|
||||
let (tls_connection, prover) = prover.setup(
|
||||
TlsClientConfig::builder()
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
// server-fixture.
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
// (Optional) Set up TLS client authentication if required by the server.
|
||||
.client_auth((
|
||||
vec![CertificateDer(CLIENT_CERT_DER.to_vec())],
|
||||
PrivateKeyDer(CLIENT_KEY_DER.to_vec()),
|
||||
))
|
||||
.build()?,
|
||||
)?;
|
||||
let tls_connection = TokioIo::new(tls_connection.compat());
|
||||
|
||||
// Spawn the prover task to be run concurrently in the background.
|
||||
let prover_task = tokio::spawn(prover_fut);
|
||||
let prover_task = tokio::spawn(prover.run(client_socket, verifier_socket));
|
||||
|
||||
// Attach the hyper HTTP client to the connection.
|
||||
let (mut request_sender, connection) =
|
||||
@@ -180,7 +181,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
assert!(response.status() == StatusCode::OK);
|
||||
|
||||
// The prover task should be done now, so we can await it.
|
||||
let prover = prover_task.await??;
|
||||
let (prover, _, verifier_socket) = prover_task.await??;
|
||||
|
||||
// Parse the HTTP transcript.
|
||||
let transcript = HttpTranscript::parse(prover.transcript())?;
|
||||
@@ -222,7 +223,8 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
|
||||
let request_config = builder.build()?;
|
||||
|
||||
let (attestation, secrets) = notarize(prover, &request_config, req_tx, resp_rx).await?;
|
||||
let (attestation, secrets) =
|
||||
notarize(prover, &request_config, verifier_socket, req_tx, resp_rx).await?;
|
||||
|
||||
// Write the attestation to disk.
|
||||
let attestation_path = tlsn_examples::get_file_path(example_type, "attestation");
|
||||
@@ -242,9 +244,10 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn notarize(
|
||||
async fn notarize<S: futures::AsyncRead + futures::AsyncWrite + Send + Unpin>(
|
||||
mut prover: Prover<Committed>,
|
||||
config: &RequestConfig,
|
||||
mut verifier_socket: S,
|
||||
request_tx: Sender<AttestationRequest>,
|
||||
attestation_rx: Receiver<Attestation>,
|
||||
) -> Result<(Attestation, Secrets)> {
|
||||
@@ -260,11 +263,13 @@ async fn notarize(
|
||||
transcript_commitments,
|
||||
transcript_secrets,
|
||||
..
|
||||
} = prover.prove(&disclosure_config).await?;
|
||||
} = prover
|
||||
.prove(&disclosure_config, &mut verifier_socket)
|
||||
.await?;
|
||||
|
||||
let transcript = prover.transcript().clone();
|
||||
let tls_transcript = prover.tls_transcript().clone();
|
||||
prover.close().await?;
|
||||
prover.close(&mut verifier_socket).await?;
|
||||
|
||||
// Build an attestation request.
|
||||
let mut builder = AttestationRequest::builder(config);
|
||||
@@ -307,10 +312,12 @@ async fn notarize(
|
||||
}
|
||||
|
||||
async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
socket: S,
|
||||
prover_socket: S,
|
||||
request_rx: Receiver<AttestationRequest>,
|
||||
attestation_tx: Sender<Attestation>,
|
||||
) -> Result<()> {
|
||||
let mut prover_socket = prover_socket.compat();
|
||||
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
// server-fixture.
|
||||
@@ -322,24 +329,29 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
.unwrap();
|
||||
|
||||
let verifier = Verifier::new(verifier_config)
|
||||
.commit(socket.compat())
|
||||
.commit(&mut prover_socket)
|
||||
.await?
|
||||
.accept()
|
||||
.accept(&mut prover_socket)
|
||||
.await?
|
||||
.run()
|
||||
.run(&mut prover_socket)
|
||||
.await?;
|
||||
|
||||
let (
|
||||
VerifierOutput {
|
||||
transcript_commitments,
|
||||
encoder_secret,
|
||||
..
|
||||
},
|
||||
verifier,
|
||||
) = verifier.verify().await?.accept().await?;
|
||||
) = verifier
|
||||
.verify(&mut prover_socket)
|
||||
.await?
|
||||
.accept(&mut prover_socket)
|
||||
.await?;
|
||||
|
||||
let tls_transcript = verifier.tls_transcript().clone();
|
||||
|
||||
verifier.close().await?;
|
||||
verifier.close(&mut prover_socket).await?;
|
||||
|
||||
let sent_len = tls_transcript
|
||||
.sent()
|
||||
@@ -392,6 +404,10 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
.server_ephemeral_key(tls_transcript.server_ephemeral_key().clone())
|
||||
.transcript_commitments(transcript_commitments);
|
||||
|
||||
if let Some(encoder_secret) = encoder_secret {
|
||||
builder.encoder_secret(encoder_secret);
|
||||
}
|
||||
|
||||
let attestation = builder.build(&provider)?;
|
||||
|
||||
// Send attestation to prover.
|
||||
|
||||
@@ -73,6 +73,8 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
server_addr: &SocketAddr,
|
||||
uri: &str,
|
||||
) -> Result<()> {
|
||||
let mut verifier_socket = verifier_socket.compat();
|
||||
|
||||
let uri = uri.parse::<Uri>().unwrap();
|
||||
assert_eq!(uri.scheme().unwrap().as_str(), "https");
|
||||
let server_domain = uri.authority().unwrap().host();
|
||||
@@ -93,32 +95,30 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
.build()?,
|
||||
)
|
||||
.build()?,
|
||||
verifier_socket.compat(),
|
||||
&mut verifier_socket,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Open a TCP connection to the server.
|
||||
let client_socket = tokio::net::TcpStream::connect(server_addr).await?;
|
||||
let client_socket = tokio::net::TcpStream::connect(server_addr).await?.compat();
|
||||
|
||||
// Bind the prover to the server connection.
|
||||
let (tls_connection, prover_fut) = prover
|
||||
.connect(
|
||||
TlsClientConfig::builder()
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
// server-fixture.
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
.build()?,
|
||||
client_socket.compat(),
|
||||
)
|
||||
.await?;
|
||||
let (tls_connection, prover) = prover.setup(
|
||||
TlsClientConfig::builder()
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
// server-fixture.
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
.build()?,
|
||||
)?;
|
||||
|
||||
let tls_connection = TokioIo::new(tls_connection.compat());
|
||||
|
||||
// Spawn the Prover to run in the background.
|
||||
let prover_task = tokio::spawn(prover_fut);
|
||||
let prover_task = tokio::spawn(prover.run(client_socket, verifier_socket));
|
||||
|
||||
// MPC-TLS Handshake.
|
||||
let (mut request_sender, connection) =
|
||||
@@ -140,7 +140,7 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
assert!(response.status() == StatusCode::OK);
|
||||
|
||||
// Create proof for the Verifier.
|
||||
let mut prover = prover_task.await??;
|
||||
let (mut prover, _, mut verifier_socket) = prover_task.await??;
|
||||
|
||||
let mut builder = ProveConfig::builder(prover.transcript());
|
||||
|
||||
@@ -173,8 +173,8 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
|
||||
let config = builder.build()?;
|
||||
|
||||
prover.prove(&config).await?;
|
||||
prover.close().await?;
|
||||
prover.prove(&config, &mut verifier_socket).await?;
|
||||
prover.close(&mut verifier_socket).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -183,6 +183,8 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
socket: T,
|
||||
) -> Result<PartialTranscript> {
|
||||
let mut socket = socket.compat();
|
||||
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
// server-fixture.
|
||||
@@ -194,7 +196,7 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
let verifier = Verifier::new(verifier_config);
|
||||
|
||||
// Validate the proposed configuration and then run the TLS commitment protocol.
|
||||
let verifier = verifier.commit(socket.compat()).await?;
|
||||
let verifier = verifier.commit(&mut socket).await?;
|
||||
|
||||
// This is the opportunity to ensure the prover does not attempt to overload the
|
||||
// verifier.
|
||||
@@ -212,21 +214,21 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
};
|
||||
|
||||
if reject.is_some() {
|
||||
verifier.reject(reject).await?;
|
||||
verifier.reject(&mut socket, reject).await?;
|
||||
return Err(anyhow::anyhow!("protocol configuration rejected"));
|
||||
}
|
||||
|
||||
// Runs the TLS commitment protocol to completion.
|
||||
let verifier = verifier.accept().await?.run().await?;
|
||||
let verifier = verifier.accept(&mut socket).await?.run(&mut socket).await?;
|
||||
|
||||
// Validate the proving request and then verify.
|
||||
let verifier = verifier.verify().await?;
|
||||
let verifier = verifier.verify(&mut socket).await?;
|
||||
|
||||
if !verifier.request().server_identity() {
|
||||
let verifier = verifier
|
||||
.reject(Some("expecting to verify the server name"))
|
||||
.reject(&mut socket, Some("expecting to verify the server name"))
|
||||
.await?;
|
||||
verifier.close().await?;
|
||||
verifier.close(&mut socket).await?;
|
||||
return Err(anyhow::anyhow!("prover did not reveal the server name"));
|
||||
}
|
||||
|
||||
@@ -237,9 +239,9 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
..
|
||||
},
|
||||
verifier,
|
||||
) = verifier.accept().await?;
|
||||
) = verifier.accept(&mut socket).await?;
|
||||
|
||||
verifier.close().await?;
|
||||
verifier.close(&mut socket).await?;
|
||||
|
||||
let server_name = server_name.expect("prover should have revealed server name");
|
||||
let transcript = transcript.expect("prover should have revealed transcript data");
|
||||
|
||||
@@ -31,11 +31,10 @@ async fn main() -> Result<()> {
|
||||
|
||||
// Connect prover and verifier.
|
||||
let (prover_socket, verifier_socket) = tokio::io::duplex(1 << 23);
|
||||
let (prover_extra_socket, verifier_extra_socket) = tokio::io::duplex(1 << 23);
|
||||
|
||||
let (_, transcript) = tokio::try_join!(
|
||||
prover(prover_socket, prover_extra_socket, &server_addr, &uri),
|
||||
verifier(verifier_socket, verifier_extra_socket)
|
||||
prover(prover_socket, &server_addr, &uri),
|
||||
verifier(verifier_socket)
|
||||
)?;
|
||||
|
||||
println!("---");
|
||||
|
||||
@@ -46,15 +46,15 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
|
||||
use tracing::instrument;
|
||||
|
||||
#[instrument(skip(verifier_socket, verifier_extra_socket))]
|
||||
#[instrument(skip(verifier_socket))]
|
||||
pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
verifier_socket: T,
|
||||
mut verifier_extra_socket: T,
|
||||
server_addr: &SocketAddr,
|
||||
uri: &str,
|
||||
) -> Result<()> {
|
||||
let uri = uri.parse::<Uri>()?;
|
||||
let mut verifier_socket = verifier_socket.compat();
|
||||
|
||||
let uri = uri.parse::<Uri>()?;
|
||||
if uri.scheme().map(|s| s.as_str()) != Some("https") {
|
||||
return Err(anyhow::anyhow!("URI must use HTTPS scheme"));
|
||||
}
|
||||
@@ -80,32 +80,29 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
.build()?,
|
||||
)
|
||||
.build()?,
|
||||
verifier_socket.compat(),
|
||||
&mut verifier_socket,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Open a TCP connection to the server.
|
||||
let client_socket = tokio::net::TcpStream::connect(server_addr).await?;
|
||||
let client_socket = tokio::net::TcpStream::connect(server_addr).await?.compat();
|
||||
|
||||
// Bind the prover to the server connection.
|
||||
let (tls_connection, prover_fut) = prover
|
||||
.connect(
|
||||
TlsClientConfig::builder()
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
// server-fixture.
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
.build()?,
|
||||
client_socket.compat(),
|
||||
)
|
||||
.await?;
|
||||
let (tls_connection, prover) = prover.setup(
|
||||
TlsClientConfig::builder()
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
// certificate. This is only required for offline testing with the
|
||||
// server-fixture.
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
.build()?,
|
||||
)?;
|
||||
let tls_connection = TokioIo::new(tls_connection.compat());
|
||||
|
||||
// Spawn the Prover to run in the background.
|
||||
let prover_task = tokio::spawn(prover_fut);
|
||||
let prover_task = tokio::spawn(prover.run(client_socket, verifier_socket));
|
||||
|
||||
// MPC-TLS Handshake.
|
||||
let (mut request_sender, connection) =
|
||||
@@ -133,7 +130,7 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
}
|
||||
|
||||
// Create proof for the Verifier.
|
||||
let mut prover = prover_task.await??;
|
||||
let (mut prover, _, mut verifier_socket) = prover_task.await??;
|
||||
|
||||
let transcript = prover.transcript().clone();
|
||||
let mut prove_config_builder = ProveConfig::builder(&transcript);
|
||||
@@ -167,8 +164,8 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
let prove_config = prove_config_builder.build()?;
|
||||
|
||||
// MPC-TLS prove
|
||||
let prover_output = prover.prove(&prove_config).await?;
|
||||
prover.close().await?;
|
||||
let prover_output = prover.prove(&prove_config, &mut verifier_socket).await?;
|
||||
prover.close(&mut verifier_socket).await?;
|
||||
|
||||
// Prove birthdate is more than 18 years ago.
|
||||
let received_commitments = received_commitments(&prover_output.transcript_commitments);
|
||||
@@ -184,8 +181,10 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
|
||||
// Sent zk proof bundle to verifier
|
||||
let serialized_proof = bincode::serialize(&proof_bundle)?;
|
||||
verifier_extra_socket.write_all(&serialized_proof).await?;
|
||||
verifier_extra_socket.shutdown().await?;
|
||||
|
||||
let mut verifier_socket = verifier_socket.into_inner();
|
||||
verifier_socket.write_all(&serialized_proof).await?;
|
||||
verifier_socket.shutdown().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -20,11 +20,12 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||
use tokio_util::compat::TokioAsyncReadCompatExt;
|
||||
use tracing::instrument;
|
||||
|
||||
#[instrument(skip(socket, extra_socket))]
|
||||
#[instrument(skip(prover_socket))]
|
||||
pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
socket: T,
|
||||
mut extra_socket: T,
|
||||
prover_socket: T,
|
||||
) -> Result<PartialTranscript> {
|
||||
let mut prover_socket = prover_socket.compat();
|
||||
|
||||
let verifier = Verifier::new(
|
||||
VerifierConfig::builder()
|
||||
// Create a root certificate store with the server-fixture's self-signed
|
||||
@@ -37,7 +38,7 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
|
||||
);
|
||||
|
||||
// Validate the proposed configuration and then run the TLS commitment protocol.
|
||||
let verifier = verifier.commit(socket.compat()).await?;
|
||||
let verifier = verifier.commit(&mut prover_socket).await?;
|
||||
|
||||
// This is the opportunity to ensure the prover does not attempt to overload the
|
||||
// verifier.
|
||||
@@ -55,24 +56,29 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
|
||||
};
|
||||
|
||||
if reject.is_some() {
|
||||
verifier.reject(reject).await?;
|
||||
verifier.reject(&mut prover_socket, reject).await?;
|
||||
return Err(anyhow::anyhow!("protocol configuration rejected"));
|
||||
}
|
||||
|
||||
// Runs the TLS commitment protocol to completion.
|
||||
let verifier = verifier.accept().await?.run().await?;
|
||||
let verifier = verifier
|
||||
.accept(&mut prover_socket)
|
||||
.await?
|
||||
.run(&mut prover_socket)
|
||||
.await?;
|
||||
|
||||
// Validate the proving request and then verify.
|
||||
let verifier = verifier.verify().await?;
|
||||
let verifier = verifier.verify(&mut prover_socket).await?;
|
||||
let request = verifier.request();
|
||||
|
||||
if !request.server_identity() || request.reveal().is_none() {
|
||||
let verifier = verifier
|
||||
.reject(Some(
|
||||
"expecting to verify the server name and transcript data",
|
||||
))
|
||||
.reject(
|
||||
&mut prover_socket,
|
||||
Some("expecting to verify the server name and transcript data"),
|
||||
)
|
||||
.await?;
|
||||
verifier.close().await?;
|
||||
verifier.close(&mut prover_socket).await?;
|
||||
return Err(anyhow::anyhow!(
|
||||
"prover did not reveal the server name and transcript data"
|
||||
));
|
||||
@@ -86,10 +92,9 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
|
||||
..
|
||||
},
|
||||
verifier,
|
||||
) = verifier.accept().await?;
|
||||
|
||||
verifier.close().await?;
|
||||
) = verifier.accept(&mut prover_socket).await?;
|
||||
|
||||
verifier.close(&mut prover_socket).await?;
|
||||
let server_name = server_name.expect("server name should be present");
|
||||
let transcript = transcript.expect("transcript should be present");
|
||||
|
||||
@@ -126,7 +131,9 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
|
||||
|
||||
// Receive ZKProof information from prover
|
||||
let mut buf = Vec::new();
|
||||
extra_socket.read_to_end(&mut buf).await?;
|
||||
|
||||
let mut prover_socket = prover_socket.into_inner();
|
||||
prover_socket.read_to_end(&mut buf).await?;
|
||||
|
||||
if buf.is_empty() {
|
||||
return Err(anyhow::anyhow!("No ZK proof data received from prover"));
|
||||
|
||||
@@ -23,7 +23,8 @@ use crate::{
|
||||
};
|
||||
|
||||
pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<ProverMetrics> {
|
||||
let verifier_io = Meter::new(provider.provide_proto_io().await?);
|
||||
let mut verifier_io = Meter::new(provider.provide_proto_io().await?);
|
||||
let mut server_io = provider.provide_server_io().await?;
|
||||
|
||||
let sent = verifier_io.sent();
|
||||
let recv = verifier_io.recv();
|
||||
@@ -49,7 +50,7 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
|
||||
.build()
|
||||
}?)
|
||||
.build()?,
|
||||
verifier_io,
|
||||
&mut verifier_io,
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -58,19 +59,18 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
|
||||
let uploaded_preprocess = sent.load(Ordering::Relaxed);
|
||||
let downloaded_preprocess = recv.load(Ordering::Relaxed);
|
||||
|
||||
let (mut conn, prover_fut) = prover
|
||||
.connect(
|
||||
TlsClientConfig::builder()
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
.build()?,
|
||||
provider.provide_server_io().await?,
|
||||
)
|
||||
.await?;
|
||||
let (mut conn, prover) = prover.setup(
|
||||
TlsClientConfig::builder()
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
})
|
||||
.build()?,
|
||||
)?;
|
||||
|
||||
let (_, mut prover) = futures::try_join!(
|
||||
let mut prover = prover.connect(&mut server_io, &mut verifier_io);
|
||||
|
||||
futures::try_join!(
|
||||
async {
|
||||
let request = format!(
|
||||
"GET /bytes?size={} HTTP/1.1\r\nConnection: close\r\nData: {}\r\n\r\n",
|
||||
@@ -87,8 +87,9 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
|
||||
|
||||
Ok(())
|
||||
},
|
||||
prover_fut.map_err(anyhow::Error::from)
|
||||
(&mut prover).map_err(anyhow::Error::from)
|
||||
)?;
|
||||
let mut prover = prover.finish()?;
|
||||
|
||||
let time_online = time_start_online.elapsed().as_millis();
|
||||
let uploaded_online = sent.load(Ordering::Relaxed) - uploaded_preprocess;
|
||||
@@ -118,8 +119,8 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
|
||||
|
||||
let prove_config = builder.build()?;
|
||||
|
||||
prover.prove(&prove_config).await?;
|
||||
prover.close().await?;
|
||||
prover.prove(&prove_config, &mut verifier_io).await?;
|
||||
prover.close(&mut verifier_io).await?;
|
||||
|
||||
let time_total = time_start.elapsed().as_millis();
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ use tlsn_server_fixture_certs::CA_CERT_DER;
|
||||
use crate::IoProvider;
|
||||
|
||||
pub async fn bench_verifier(provider: &IoProvider, _config: &Bench) -> Result<()> {
|
||||
let mut prover_io = provider.provide_proto_io().await?;
|
||||
|
||||
let verifier = Verifier::new(
|
||||
VerifierConfig::builder()
|
||||
.root_store(RootCertStore {
|
||||
@@ -22,12 +24,16 @@ pub async fn bench_verifier(provider: &IoProvider, _config: &Bench) -> Result<()
|
||||
let verifier = verifier
|
||||
.commit(provider.provide_proto_io().await?)
|
||||
.await?
|
||||
.accept()
|
||||
.accept(&mut prover_io)
|
||||
.await?
|
||||
.run()
|
||||
.run(&mut prover_io)
|
||||
.await?;
|
||||
let (_, verifier) = verifier.verify().await?.accept().await?;
|
||||
verifier.close().await?;
|
||||
let (_, verifier) = verifier
|
||||
.verify(&mut prover_io)
|
||||
.await?
|
||||
.accept(&mut prover_io)
|
||||
.await?;
|
||||
verifier.close(&mut prover_io).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -28,6 +28,8 @@ const MAX_RECV_DATA: usize = 1 << 11;
|
||||
crate::test!("basic", prover, verifier);
|
||||
|
||||
async fn prover(provider: &IoProvider) {
|
||||
let mut verifier_io = provider.provide_proto_io().await.unwrap();
|
||||
|
||||
let prover = Prover::new(ProverConfig::builder().build().unwrap())
|
||||
.commit(
|
||||
TlsCommitConfig::builder()
|
||||
@@ -41,13 +43,15 @@ async fn prover(provider: &IoProvider) {
|
||||
)
|
||||
.build()
|
||||
.unwrap(),
|
||||
provider.provide_proto_io().await.unwrap(),
|
||||
&mut verifier_io,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (tls_connection, prover_fut) = prover
|
||||
.connect(
|
||||
let server_io = provider.provide_server_io().await.unwrap();
|
||||
|
||||
let (tls_connection, prover) = prover
|
||||
.setup(
|
||||
TlsClientConfig::builder()
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
|
||||
.root_store(RootCertStore {
|
||||
@@ -55,12 +59,10 @@ async fn prover(provider: &IoProvider) {
|
||||
})
|
||||
.build()
|
||||
.unwrap(),
|
||||
provider.provide_server_io().await.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let prover_task = spawn(prover_fut);
|
||||
let prover_task = spawn(prover.run(server_io, verifier_io));
|
||||
|
||||
let (mut request_sender, connection) =
|
||||
hyper::client::conn::http1::handshake(FuturesIo::new(tls_connection))
|
||||
@@ -87,7 +89,7 @@ async fn prover(provider: &IoProvider) {
|
||||
|
||||
let _ = response.into_body().collect().await.unwrap().to_bytes();
|
||||
|
||||
let mut prover = prover_task.await.unwrap().unwrap();
|
||||
let (mut prover, _, mut verifier_io) = prover_task.await.unwrap().unwrap();
|
||||
|
||||
let (sent_len, recv_len) = prover.transcript().len();
|
||||
|
||||
@@ -114,11 +116,13 @@ async fn prover(provider: &IoProvider) {
|
||||
|
||||
let config = builder.build().unwrap();
|
||||
|
||||
prover.prove(&config).await.unwrap();
|
||||
prover.close().await.unwrap();
|
||||
prover.prove(&config, &mut verifier_io).await.unwrap();
|
||||
prover.close(&mut verifier_io).await.unwrap();
|
||||
}
|
||||
|
||||
async fn verifier(provider: &IoProvider) {
|
||||
let mut prover_io = provider.provide_proto_io().await.unwrap();
|
||||
|
||||
let config = VerifierConfig::builder()
|
||||
.root_store(RootCertStore {
|
||||
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
|
||||
@@ -127,13 +131,13 @@ async fn verifier(provider: &IoProvider) {
|
||||
.unwrap();
|
||||
|
||||
let verifier = Verifier::new(config)
|
||||
.commit(provider.provide_proto_io().await.unwrap())
|
||||
.commit(&mut prover_io)
|
||||
.await
|
||||
.unwrap()
|
||||
.accept()
|
||||
.accept(&mut prover_io)
|
||||
.await
|
||||
.unwrap()
|
||||
.run()
|
||||
.run(&mut prover_io)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -144,9 +148,15 @@ async fn verifier(provider: &IoProvider) {
|
||||
..
|
||||
},
|
||||
verifier,
|
||||
) = verifier.verify().await.unwrap().accept().await.unwrap();
|
||||
) = verifier
|
||||
.verify(&mut prover_io)
|
||||
.await
|
||||
.unwrap()
|
||||
.accept(&mut prover_io)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
verifier.close().await.unwrap();
|
||||
verifier.close(&mut prover_io).await.unwrap();
|
||||
|
||||
let ServerName::Dns(server_name) = server_name.unwrap();
|
||||
|
||||
|
||||
@@ -66,7 +66,6 @@ rand_chacha = { workspace = true }
|
||||
rstest = { workspace = true }
|
||||
tls-server-fixture = { workspace = true }
|
||||
tlsn-tls-client = { workspace = true }
|
||||
tlsn-tls-client-async = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
|
||||
tokio-util = { workspace = true, features = ["compat"] }
|
||||
tracing-subscriber = { workspace = true }
|
||||
|
||||
@@ -378,15 +378,29 @@ impl MpcTlsLeader {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Defers decryption of any incoming messages.
|
||||
/// Enables or disables the decryption of any incoming messages.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `enable` - Whether to enable or disable decryption.
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
pub async fn defer_decryption(&mut self) -> Result<(), MpcTlsError> {
|
||||
self.is_decrypting = false;
|
||||
self.notifier.clear();
|
||||
pub fn enable_decryption(&mut self, enable: bool) -> Result<(), MpcTlsError> {
|
||||
self.is_decrypting = enable;
|
||||
|
||||
if enable {
|
||||
self.notifier.set();
|
||||
} else {
|
||||
self.notifier.clear();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns if incoming messages are decrypted.
|
||||
pub fn is_decrypting(&self) -> bool {
|
||||
self.is_decrypting
|
||||
}
|
||||
|
||||
/// Stops the actor.
|
||||
pub fn stop(&mut self, ctx: &mut LudiContext<Self>) {
|
||||
ctx.stop();
|
||||
|
||||
@@ -32,10 +32,14 @@ impl MpcTlsLeaderCtrl {
|
||||
Self { address }
|
||||
}
|
||||
|
||||
/// Defers decryption of any incoming messages.
|
||||
pub async fn defer_decryption(&self) -> Result<(), MpcTlsError> {
|
||||
/// Enables or disables the decryption of any incoming messages.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `enable` - Whether to enable or disable decryption.
|
||||
pub async fn enable_decryption(&self, enable: bool) -> Result<(), MpcTlsError> {
|
||||
self.address
|
||||
.send(DeferDecryption)
|
||||
.send(EnableDecryption { enable })
|
||||
.await
|
||||
.map_err(MpcTlsError::actor)?
|
||||
}
|
||||
@@ -981,7 +985,7 @@ impl Handler<BackendMsgServerClosed> for MpcTlsLeader {
|
||||
}
|
||||
}
|
||||
|
||||
impl Dispatch<MpcTlsLeader> for DeferDecryption {
|
||||
impl Dispatch<MpcTlsLeader> for EnableDecryption {
|
||||
fn dispatch<R: FnOnce(Self::Return) + Send>(
|
||||
self,
|
||||
actor: &mut MpcTlsLeader,
|
||||
@@ -992,13 +996,13 @@ impl Dispatch<MpcTlsLeader> for DeferDecryption {
|
||||
}
|
||||
}
|
||||
|
||||
impl Handler<DeferDecryption> for MpcTlsLeader {
|
||||
impl Handler<EnableDecryption> for MpcTlsLeader {
|
||||
async fn handle(
|
||||
&mut self,
|
||||
_msg: DeferDecryption,
|
||||
msg: EnableDecryption,
|
||||
_ctx: &mut LudiCtx<Self>,
|
||||
) -> <DeferDecryption as Message>::Return {
|
||||
self.defer_decryption().await
|
||||
) -> <EnableDecryption as Message>::Return {
|
||||
self.enable_decryption(msg.enable)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1048,7 +1052,7 @@ pub enum MpcTlsLeaderMsg {
|
||||
BackendMsgGetNotify(BackendMsgGetNotify),
|
||||
BackendMsgIsEmpty(BackendMsgIsEmpty),
|
||||
BackendMsgServerClosed(BackendMsgServerClosed),
|
||||
DeferDecryption(DeferDecryption),
|
||||
DeferDecryption(EnableDecryption),
|
||||
Stop(Stop),
|
||||
}
|
||||
|
||||
@@ -1083,7 +1087,7 @@ pub enum MpcTlsLeaderMsgReturn {
|
||||
BackendMsgGetNotify(<BackendMsgGetNotify as Message>::Return),
|
||||
BackendMsgIsEmpty(<BackendMsgIsEmpty as Message>::Return),
|
||||
BackendMsgServerClosed(<BackendMsgServerClosed as Message>::Return),
|
||||
DeferDecryption(<DeferDecryption as Message>::Return),
|
||||
DeferDecryption(<EnableDecryption as Message>::Return),
|
||||
Stop(<Stop as Message>::Return),
|
||||
}
|
||||
|
||||
@@ -1732,23 +1736,25 @@ impl Wrap<BackendMsgServerClosed> for MpcTlsLeaderMsg {
|
||||
}
|
||||
}
|
||||
|
||||
/// Message to start deferring the decryption.
|
||||
/// Message to enable or disable the decryption of messages.
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug)]
|
||||
pub struct DeferDecryption;
|
||||
pub struct EnableDecryption {
|
||||
pub enable: bool,
|
||||
}
|
||||
|
||||
impl Message for DeferDecryption {
|
||||
impl Message for EnableDecryption {
|
||||
type Return = Result<(), MpcTlsError>;
|
||||
}
|
||||
|
||||
impl From<DeferDecryption> for MpcTlsLeaderMsg {
|
||||
fn from(value: DeferDecryption) -> Self {
|
||||
impl From<EnableDecryption> for MpcTlsLeaderMsg {
|
||||
fn from(value: EnableDecryption) -> Self {
|
||||
MpcTlsLeaderMsg::DeferDecryption(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Wrap<DeferDecryption> for MpcTlsLeaderMsg {
|
||||
fn unwrap_return(ret: Self::Return) -> Result<<DeferDecryption as Message>::Return, Error> {
|
||||
impl Wrap<EnableDecryption> for MpcTlsLeaderMsg {
|
||||
fn unwrap_return(ret: Self::Return) -> Result<<EnableDecryption as Message>::Return, Error> {
|
||||
match ret {
|
||||
Self::Return::DeferDecryption(value) => Ok(value),
|
||||
_ => Err(Error::Wrapper),
|
||||
|
||||
@@ -1,160 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
||||
use mpc_tls::{Config, MpcTlsFollower, MpcTlsLeader};
|
||||
use mpz_common::context::test_mt_context;
|
||||
use mpz_core::Block;
|
||||
use mpz_ideal_vm::IdealVm;
|
||||
use mpz_memory_core::correlated::Delta;
|
||||
use mpz_ot::{
|
||||
ideal::rcot::ideal_rcot,
|
||||
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
|
||||
};
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
use rustls_pki_types::CertificateDer;
|
||||
use tls_client::RootCertStore;
|
||||
use tls_client_async::bind_client;
|
||||
use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_util::compat::TokioAsyncReadCompatExt;
|
||||
use webpki::anchor_from_trusted_cert;
|
||||
|
||||
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn mpc_tls_test() {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let config = Config::builder()
|
||||
.defer_decryption(false)
|
||||
.max_sent(1 << 13)
|
||||
.max_recv_online(1 << 13)
|
||||
.max_recv(1 << 13)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let (leader, follower) = build_pair(config);
|
||||
|
||||
tokio::try_join!(
|
||||
tokio::spawn(leader_task(leader)),
|
||||
tokio::spawn(follower_task(follower))
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn leader_task(mut leader: MpcTlsLeader) {
|
||||
leader.alloc().unwrap();
|
||||
|
||||
leader.preprocess().await.unwrap();
|
||||
|
||||
let (leader_ctrl, leader_fut) = leader.run();
|
||||
tokio::spawn(async { leader_fut.await.unwrap() });
|
||||
|
||||
let config = tls_client::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(RootCertStore {
|
||||
roots: vec![anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned()],
|
||||
})
|
||||
.with_no_client_auth();
|
||||
|
||||
let server_name = SERVER_DOMAIN.try_into().unwrap();
|
||||
|
||||
let client = tls_client::ClientConnection::new(
|
||||
Arc::new(config),
|
||||
Box::new(leader_ctrl.clone()),
|
||||
server_name,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
|
||||
tokio::spawn(bind_test_server_hyper(server_socket.compat()));
|
||||
|
||||
let (mut conn, conn_fut) = bind_client(client_socket.compat(), client);
|
||||
let handle = tokio::spawn(async { conn_fut.await.unwrap() });
|
||||
|
||||
let msg = concat!(
|
||||
"POST /echo HTTP/1.1\r\n",
|
||||
"Host: test-server.io\r\n",
|
||||
"Connection: keep-alive\r\n",
|
||||
"Accept-Encoding: identity\r\n",
|
||||
"Content-Length: 5\r\n",
|
||||
"\r\n",
|
||||
"hello",
|
||||
"\r\n"
|
||||
);
|
||||
|
||||
conn.write_all(msg.as_bytes()).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 48];
|
||||
conn.read_exact(&mut buf).await.unwrap();
|
||||
|
||||
leader_ctrl.defer_decryption().await.unwrap();
|
||||
|
||||
let msg = concat!(
|
||||
"POST /echo HTTP/1.1\r\n",
|
||||
"Host: test-server.io\r\n",
|
||||
"Connection: close\r\n",
|
||||
"Accept-Encoding: identity\r\n",
|
||||
"Content-Length: 5\r\n",
|
||||
"\r\n",
|
||||
"hello",
|
||||
"\r\n"
|
||||
);
|
||||
|
||||
conn.write_all(msg.as_bytes()).await.unwrap();
|
||||
conn.close().await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
conn.read_to_end(&mut buf).await.unwrap();
|
||||
|
||||
leader_ctrl.stop().await.unwrap();
|
||||
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
async fn follower_task(mut follower: MpcTlsFollower) {
|
||||
follower.alloc().unwrap();
|
||||
follower.preprocess().await.unwrap();
|
||||
follower.run().await.unwrap();
|
||||
}
|
||||
|
||||
fn build_pair(config: Config) -> (MpcTlsLeader, MpcTlsFollower) {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
|
||||
let (mut mt_a, mut mt_b) = test_mt_context(8);
|
||||
|
||||
let ctx_a = futures::executor::block_on(mt_a.new_context()).unwrap();
|
||||
let ctx_b = futures::executor::block_on(mt_b.new_context()).unwrap();
|
||||
|
||||
let delta_a = Delta::new(Block::random(&mut rng));
|
||||
let delta_b = Delta::new(Block::random(&mut rng));
|
||||
|
||||
let (rcot_send_a, rcot_recv_b) = ideal_rcot(Block::random(&mut rng), delta_a.into_inner());
|
||||
let (rcot_send_b, rcot_recv_a) = ideal_rcot(Block::random(&mut rng), delta_b.into_inner());
|
||||
|
||||
let rcot_send_a = SharedRCOTSender::new(rcot_send_a);
|
||||
let rcot_send_b = SharedRCOTSender::new(rcot_send_b);
|
||||
let rcot_recv_a = SharedRCOTReceiver::new(rcot_recv_a);
|
||||
let rcot_recv_b = SharedRCOTReceiver::new(rcot_recv_b);
|
||||
|
||||
let mpc_a = Arc::new(Mutex::new(IdealVm::new()));
|
||||
let mpc_b = Arc::new(Mutex::new(IdealVm::new()));
|
||||
|
||||
let leader = MpcTlsLeader::new(
|
||||
config.clone(),
|
||||
ctx_a,
|
||||
mpc_a,
|
||||
(rcot_send_a.clone(), rcot_send_a.clone(), rcot_send_a),
|
||||
rcot_recv_a,
|
||||
);
|
||||
|
||||
let follower = MpcTlsFollower::new(
|
||||
config,
|
||||
ctx_b,
|
||||
mpc_b,
|
||||
rcot_send_b,
|
||||
(rcot_recv_b.clone(), rcot_recv_b.clone(), rcot_recv_b),
|
||||
);
|
||||
|
||||
(leader, follower)
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
[package]
|
||||
name = "tlsn-tls-client-async"
|
||||
authors = ["TLSNotary Team"]
|
||||
description = "An async TLS client for TLSNotary"
|
||||
keywords = ["tls", "mpc", "2pc", "client", "async"]
|
||||
categories = ["cryptography"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version = "0.1.0-alpha.14-pre"
|
||||
edition = "2021"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
name = "tls_client_async"
|
||||
|
||||
[features]
|
||||
default = ["tracing"]
|
||||
tracing = ["dep:tracing"]
|
||||
|
||||
[dependencies]
|
||||
tlsn-tls-client = { workspace = true }
|
||||
|
||||
bytes = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio-util = { workspace = true, features = ["io", "compat"] }
|
||||
tracing = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tls-server-fixture = { workspace = true }
|
||||
|
||||
http-body-util = { workspace = true }
|
||||
hyper = { workspace = true, features = ["client", "http1"] }
|
||||
hyper-util = { workspace = true, features = ["full"] }
|
||||
rstest = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] }
|
||||
rustls-webpki = { workspace = true }
|
||||
rustls-pki-types = { workspace = true }
|
||||
@@ -1,89 +0,0 @@
|
||||
use bytes::Bytes;
|
||||
use futures::{
|
||||
channel::mpsc::{Receiver, SendError, Sender},
|
||||
sink::SinkMapErr,
|
||||
AsyncRead, AsyncWrite, SinkExt,
|
||||
};
|
||||
use std::{
|
||||
io::{Error as IoError, ErrorKind as IoErrorKind},
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio_util::{
|
||||
compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt},
|
||||
io::{CopyToBytes, SinkWriter, StreamReader},
|
||||
};
|
||||
|
||||
type CompatSinkWriter =
|
||||
Compat<SinkWriter<CopyToBytes<SinkMapErr<Sender<Bytes>, fn(SendError) -> IoError>>>>;
|
||||
|
||||
/// A TLS connection to a server.
|
||||
///
|
||||
/// This type implements `AsyncRead` and `AsyncWrite` and can be used to
|
||||
/// communicate with a server using TLS.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// This connection is closed on a best-effort basis if this is dropped. To
|
||||
/// ensure a clean close, you should call
|
||||
/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the
|
||||
/// connection.
|
||||
#[derive(Debug)]
|
||||
pub struct TlsConnection {
|
||||
/// The data to be transmitted to the server is sent to this sink.
|
||||
tx_sender: CompatSinkWriter,
|
||||
/// The data to be received from the server is received from this stream.
|
||||
rx_receiver: Compat<StreamReader<Receiver<Result<Bytes, IoError>>, Bytes>>,
|
||||
}
|
||||
|
||||
impl TlsConnection {
|
||||
/// Creates a new TLS connection.
|
||||
pub(crate) fn new(
|
||||
tx_sender: Sender<Bytes>,
|
||||
rx_receiver: Receiver<Result<Bytes, IoError>>,
|
||||
) -> Self {
|
||||
fn convert_error(err: SendError) -> IoError {
|
||||
if err.is_disconnected() {
|
||||
IoErrorKind::BrokenPipe.into()
|
||||
} else {
|
||||
IoErrorKind::WouldBlock.into()
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
tx_sender: SinkWriter::new(CopyToBytes::new(
|
||||
tx_sender.sink_map_err(convert_error as fn(SendError) -> IoError),
|
||||
))
|
||||
.compat_write(),
|
||||
rx_receiver: StreamReader::new(rx_receiver).compat(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for TlsConnection {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<Result<usize, IoError>> {
|
||||
Pin::new(&mut self.rx_receiver).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TlsConnection {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, IoError>> {
|
||||
Pin::new(&mut self.tx_sender).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
|
||||
Pin::new(&mut self.tx_sender).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
|
||||
Pin::new(&mut self.tx_sender).poll_close(cx)
|
||||
}
|
||||
}
|
||||
@@ -1,269 +0,0 @@
|
||||
//! Provides a TLS client which exposes an async socket.
|
||||
//!
|
||||
//! This library provides the [bind_client] function which attaches a TLS client
|
||||
//! to a socket connection and then exposes a [TlsConnection] object, which
|
||||
//! provides an async socket API for reading and writing cleartext. The TLS
|
||||
//! client will then automatically encrypt and decrypt traffic and forward that
|
||||
//! to the provided socket.
|
||||
|
||||
#![deny(missing_docs, unreachable_pub, unused_must_use)]
|
||||
#![deny(clippy::all)]
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
mod conn;
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
use futures::{
|
||||
channel::mpsc, future::Fuse, select_biased, stream::Next, AsyncRead, AsyncReadExt, AsyncWrite,
|
||||
AsyncWriteExt, Future, FutureExt, SinkExt, StreamExt,
|
||||
};
|
||||
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
use tracing::{debug, debug_span, trace, warn, Instrument};
|
||||
|
||||
use tls_client::ClientConnection;
|
||||
|
||||
pub use conn::TlsConnection;
|
||||
|
||||
const RX_TLS_BUF_SIZE: usize = 1 << 13; // 8 KiB
|
||||
const RX_BUF_SIZE: usize = 1 << 13; // 8 KiB
|
||||
|
||||
/// An error that can occur during a TLS connection.
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConnectionError {
|
||||
#[error(transparent)]
|
||||
TlsError(#[from] tls_client::Error),
|
||||
#[error(transparent)]
|
||||
IOError(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
/// Closed connection data.
|
||||
#[derive(Debug)]
|
||||
pub struct ClosedConnection {
|
||||
/// The connection for the client
|
||||
pub client: ClientConnection,
|
||||
/// Sent plaintext bytes
|
||||
pub sent: Vec<u8>,
|
||||
/// Received plaintext bytes
|
||||
pub recv: Vec<u8>,
|
||||
}
|
||||
|
||||
/// A future which runs the TLS connection to completion.
|
||||
///
|
||||
/// This future must be polled in order for the connection to make progress.
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct ConnectionFuture {
|
||||
fut: Pin<Box<dyn Future<Output = Result<ClosedConnection, ConnectionError>> + Send>>,
|
||||
}
|
||||
|
||||
impl Future for ConnectionFuture {
|
||||
type Output = Result<ClosedConnection, ConnectionError>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.fut.poll_unpin(cx)
|
||||
}
|
||||
}
|
||||
|
||||
/// Binds a client connection to the provided socket.
|
||||
///
|
||||
/// Returns a connection handle and a future which runs the connection to
|
||||
/// completion.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Any connection errors that occur will be returned from the future, not
|
||||
/// [`TlsConnection`].
|
||||
pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
|
||||
socket: T,
|
||||
mut client: ClientConnection,
|
||||
) -> (TlsConnection, ConnectionFuture) {
|
||||
let (tx_sender, mut tx_receiver) = mpsc::channel(1 << 14);
|
||||
let (mut rx_sender, rx_receiver) = mpsc::channel(1 << 14);
|
||||
|
||||
let conn = TlsConnection::new(tx_sender, rx_receiver);
|
||||
|
||||
let fut = async move {
|
||||
client.start().await?;
|
||||
let mut notify = client.get_notify().await?;
|
||||
|
||||
let (mut server_rx, mut server_tx) = socket.split();
|
||||
|
||||
let mut rx_tls_buf = [0u8; RX_TLS_BUF_SIZE];
|
||||
let mut rx_buf = [0u8; RX_BUF_SIZE];
|
||||
|
||||
let mut handshake_done = false;
|
||||
let mut client_closed = false;
|
||||
let mut server_closed = false;
|
||||
|
||||
let mut sent = Vec::with_capacity(1024);
|
||||
let mut recv = Vec::with_capacity(1024);
|
||||
|
||||
let mut rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse();
|
||||
// We don't start writing application data until the handshake is complete.
|
||||
let mut tx_recv_fut: Fuse<Next<'_, mpsc::Receiver<Bytes>>> = Fuse::terminated();
|
||||
|
||||
// Runs both the tx and rx halves of the connection to completion.
|
||||
// This loop does not terminate until the *SERVER* closes the connection and
|
||||
// we've processed all received data. If an error occurs, the `TlsConnection`
|
||||
// channels will be closed and the error will be returned from this future.
|
||||
'conn: loop {
|
||||
// Write all pending TLS data to the server.
|
||||
if client.wants_write() && !client_closed {
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("client wants to write");
|
||||
while client.wants_write() {
|
||||
let _sent = client.write_tls_async(&mut server_tx).await?;
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("sent {} tls bytes to server", _sent);
|
||||
}
|
||||
server_tx.flush().await?;
|
||||
}
|
||||
|
||||
// Forward received plaintext to `TlsConnection`.
|
||||
while !client.plaintext_is_empty() {
|
||||
let read = client.read_plaintext(&mut rx_buf)?;
|
||||
recv.extend(&rx_buf[..read]);
|
||||
// Ignore if the receiver has hung up.
|
||||
_ = rx_sender
|
||||
.send(Ok(Bytes::copy_from_slice(&rx_buf[..read])))
|
||||
.await;
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("forwarded {} plaintext bytes to conn", read);
|
||||
}
|
||||
|
||||
if !client.is_handshaking() && !handshake_done {
|
||||
#[cfg(feature = "tracing")]
|
||||
debug!("handshake complete");
|
||||
handshake_done = true;
|
||||
// Start reading application data that needs to be transmitted from the
|
||||
// `TlsConnection`.
|
||||
tx_recv_fut = tx_receiver.next().fuse();
|
||||
}
|
||||
|
||||
if server_closed && client.plaintext_is_empty() && client.is_empty().await? {
|
||||
break 'conn;
|
||||
}
|
||||
|
||||
select_biased! {
|
||||
// Reads TLS data from the server and writes it into the client.
|
||||
received = &mut rx_tls_fut => {
|
||||
let received = received?;
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("received {} tls bytes from server", received);
|
||||
|
||||
// Loop until we've processed all the data we received in this read.
|
||||
// Note that we must make one iteration even if `received == 0`.
|
||||
let mut processed = 0;
|
||||
let mut reader = rx_tls_buf[..received].reader();
|
||||
loop {
|
||||
processed += client.read_tls(&mut reader)?;
|
||||
client.process_new_packets().await?;
|
||||
|
||||
debug_assert!(processed <= received);
|
||||
if processed >= received {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("processed {} tls bytes from server", processed);
|
||||
|
||||
// By convention if `AsyncRead::read` returns 0, it means EOF, i.e. the peer
|
||||
// has closed the socket.
|
||||
if received == 0 {
|
||||
#[cfg(feature = "tracing")]
|
||||
debug!("server closed connection");
|
||||
server_closed = true;
|
||||
client.server_closed().await?;
|
||||
// Do not read from the socket again.
|
||||
rx_tls_fut = Fuse::terminated();
|
||||
} else {
|
||||
// Reset the read future so next iteration we can read again.
|
||||
rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse();
|
||||
}
|
||||
}
|
||||
// If we receive None from `TlsConnection`, it has closed, so we
|
||||
// send a close_notify to the server.
|
||||
data = &mut tx_recv_fut => {
|
||||
if let Some(data) = data {
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("writing {} plaintext bytes to client", data.len());
|
||||
|
||||
sent.extend(&data);
|
||||
client
|
||||
.write_all_plaintext(&data)
|
||||
.await?;
|
||||
|
||||
tx_recv_fut = tx_receiver.next().fuse();
|
||||
} else {
|
||||
if !server_closed {
|
||||
if let Err(e) = send_close_notify(&mut client, &mut server_tx).await {
|
||||
#[cfg(feature = "tracing")]
|
||||
warn!("failed to send close_notify to server: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
client_closed = true;
|
||||
|
||||
tx_recv_fut = Fuse::terminated();
|
||||
}
|
||||
}
|
||||
// Waits for a notification from the backend that it is ready to decrypt data.
|
||||
_ = &mut notify => {
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("backend is ready to decrypt");
|
||||
|
||||
client.process_new_packets().await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
debug!("client shutdown");
|
||||
|
||||
_ = server_tx.close().await;
|
||||
tx_receiver.close();
|
||||
rx_sender.close_channel();
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!(
|
||||
"server close notify: {}, sent: {}, recv: {}",
|
||||
client.received_close_notify(),
|
||||
sent.len(),
|
||||
recv.len()
|
||||
);
|
||||
|
||||
Ok(ClosedConnection { client, sent, recv })
|
||||
};
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
let fut = fut.instrument(debug_span!("tls_connection"));
|
||||
|
||||
let fut = ConnectionFuture { fut: Box::pin(fut) };
|
||||
|
||||
(conn, fut)
|
||||
}
|
||||
|
||||
async fn send_close_notify(
|
||||
client: &mut ClientConnection,
|
||||
server_tx: &mut (impl AsyncWrite + Unpin),
|
||||
) -> Result<(), ConnectionError> {
|
||||
#[cfg(feature = "tracing")]
|
||||
trace!("sending close_notify to server");
|
||||
client.send_close_notify().await?;
|
||||
client.process_new_packets().await?;
|
||||
|
||||
// Flush all remaining plaintext
|
||||
while client.wants_write() {
|
||||
client.write_tls_async(server_tx).await?;
|
||||
}
|
||||
server_tx.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,438 +0,0 @@
|
||||
use std::{str, sync::Arc};
|
||||
|
||||
use core::future::Future;
|
||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
||||
use http_body_util::{BodyExt as _, Full};
|
||||
use hyper::{body::Bytes, Request, StatusCode};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use rstest::{fixture, rstest};
|
||||
use rustls_pki_types::CertificateDer;
|
||||
use tls_client::{ClientConfig, ClientConnection, RustCryptoBackend, ServerName};
|
||||
use tls_client_async::{bind_client, ClosedConnection, ConnectionError, TlsConnection};
|
||||
use tls_server_fixture::{
|
||||
bind_test_server, bind_test_server_hyper, APP_RECORD_LENGTH, CA_CERT_DER, CLOSE_DELAY,
|
||||
SERVER_DOMAIN,
|
||||
};
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
|
||||
use webpki::anchor_from_trusted_cert;
|
||||
|
||||
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
|
||||
|
||||
// An established client TLS connection
|
||||
struct TlsFixture {
|
||||
client_tls_conn: TlsConnection,
|
||||
// a handle that must be `.await`ed to get the result of a TLS connection
|
||||
closed_tls_task: JoinHandle<Result<ClosedConnection, ConnectionError>>,
|
||||
}
|
||||
|
||||
// Sets up a TLS connection between client and server and sends a hello message
|
||||
#[fixture]
|
||||
async fn set_up_tls() -> TlsFixture {
|
||||
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
|
||||
|
||||
let _server_task = tokio::spawn(bind_test_server(server_socket.compat()));
|
||||
|
||||
let mut root_store = tls_client::RootCertStore::empty();
|
||||
root_store
|
||||
.roots
|
||||
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
|
||||
let config = ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
let client = ClientConnection::new(
|
||||
Arc::new(config),
|
||||
Box::new(RustCryptoBackend::new()),
|
||||
ServerName::try_from(SERVER_DOMAIN).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (mut client_tls_conn, tls_fut) = bind_client(client_socket.compat(), client);
|
||||
|
||||
let closed_tls_task = tokio::spawn(tls_fut);
|
||||
|
||||
client_tls_conn
|
||||
.write_all(&pad("expecting you to send back hello".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// give the server some time to respond
|
||||
std::thread::sleep(std::time::Duration::from_millis(10));
|
||||
|
||||
let mut plaintext = vec![0u8; 320];
|
||||
let n = client_tls_conn.read(&mut plaintext).await.unwrap();
|
||||
let s = str::from_utf8(&plaintext[0..n]).unwrap();
|
||||
|
||||
assert_eq!(s, "hello");
|
||||
|
||||
TlsFixture {
|
||||
client_tls_conn,
|
||||
closed_tls_task,
|
||||
}
|
||||
}
|
||||
|
||||
// Expect the async tls client wrapped in `hyper::client` to make a successful
|
||||
// request and receive the expected response
|
||||
#[tokio::test]
|
||||
async fn test_hyper_ok() {
|
||||
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
|
||||
|
||||
let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat()));
|
||||
|
||||
let mut root_store = tls_client::RootCertStore::empty();
|
||||
root_store
|
||||
.roots
|
||||
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
|
||||
let config = ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
let client = ClientConnection::new(
|
||||
Arc::new(config),
|
||||
Box::new(RustCryptoBackend::new()),
|
||||
ServerName::try_from(SERVER_DOMAIN).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (conn, tls_fut) = bind_client(client_socket.compat(), client);
|
||||
|
||||
let closed_tls_task = tokio::spawn(tls_fut);
|
||||
|
||||
let (mut request_sender, connection) =
|
||||
hyper::client::conn::http1::handshake(TokioIo::new(conn.compat()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
tokio::spawn(connection);
|
||||
|
||||
let request = Request::builder()
|
||||
.uri(format!("https://{SERVER_DOMAIN}/echo"))
|
||||
.header("Host", SERVER_DOMAIN)
|
||||
.header("Connection", "close")
|
||||
.method("POST")
|
||||
.body(Full::<Bytes>::new("hello".into()))
|
||||
.unwrap();
|
||||
|
||||
let response = request_sender.send_request(request).await.unwrap();
|
||||
|
||||
assert!(response.status() == StatusCode::OK);
|
||||
|
||||
// Process the response body
|
||||
response.into_body().collect().await.unwrap().to_bytes();
|
||||
|
||||
let _ = server_task.await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect a clean TLS connection closure when server responds to the client's
|
||||
// close_notify but doesn't close the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_server_no_socket_close(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send close_notify back to us after 10 ms
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_close_notify".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// closing `client_tls_conn` will cause close_notify to be sent by the client;
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect a clean TLS connection closure when server responds to the client's
|
||||
// close_notify AND also closes the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_server_socket_close(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send close_notify back to us AND close the socket
|
||||
// after 10 ms
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_close_notify_and_close_socket".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// closing `client_tls_conn` will cause close_notify to be sent by the client;
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect a clean TLS connection closure when server sends close_notify first
|
||||
// but doesn't close the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_server_close_notify(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send close_notify back to us after 10 ms
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_close_notify".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// give enough time for server's close_notify to arrive
|
||||
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
|
||||
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect a clean TLS connection closure when server sends close_notify first
|
||||
// AND also closes the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_server_close_notify_and_socket_close(
|
||||
set_up_tls: impl Future<Output = TlsFixture>,
|
||||
) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send close_notify back to us after 10 ms
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_close_notify_and_close_socket".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// give enough time for server's close_notify to arrive
|
||||
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
|
||||
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect to be able to read the data after server closes the socket abruptly
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_read_after_close(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
..
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send us a hello message
|
||||
client_tls_conn
|
||||
.write_all(&pad("send a hello message".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// instruct the server to close the socket
|
||||
client_tls_conn
|
||||
.write_all(&pad("close_socket".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// give enough time to close the socket
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
|
||||
// try to read some more data
|
||||
let mut buf = vec![0u8; 10];
|
||||
let n = client_tls_conn.read(&mut buf).await.unwrap();
|
||||
|
||||
assert_eq!(std::str::from_utf8(&buf[0..n]).unwrap(), "hello");
|
||||
}
|
||||
|
||||
// Expect there to be no error when server DOES NOT send close_notify but just
|
||||
// closes the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_server_no_close_notify(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to close the socket
|
||||
client_tls_conn
|
||||
.write_all(&pad("close_socket".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// give enough time to close the socket
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
|
||||
assert!(!closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect to register a delay when the server delays closing the socket
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_ok_delay_close(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
client_tls_conn
|
||||
.write_all(&pad("must_delay_when_closing".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// closing `client_tls_conn` will cause close_notify to be sent by the client
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
use std::time::Instant;
|
||||
let now = Instant::now();
|
||||
// this will resolve when the server stops delaying closing the socket
|
||||
let closed_conn = closed_tls_task.await.unwrap().unwrap();
|
||||
let elapsed = now.elapsed();
|
||||
|
||||
// the elapsed time must be roughly equal to the server's delay
|
||||
// (give or take timing variations)
|
||||
assert!(elapsed.as_millis() as u64 > CLOSE_DELAY - 50);
|
||||
|
||||
assert!(!closed_conn.client.received_close_notify());
|
||||
}
|
||||
|
||||
// Expect client to error when server sends a corrupted message
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_err_corrupted(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send a corrupted message
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_corrupted_message".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
closed_tls_task.await.unwrap().err().unwrap().to_string(),
|
||||
"received corrupt message"
|
||||
);
|
||||
}
|
||||
|
||||
// Expect client to error when server sends a TLS record with a bad MAC
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_err_bad_mac(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send us a TLS record with a bad MAC
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_record_with_bad_mac".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
closed_tls_task.await.unwrap().err().unwrap().to_string(),
|
||||
"backend error: Decryption error: \"aead::Error\""
|
||||
);
|
||||
}
|
||||
|
||||
// Expect client to error when server sends a fatal alert
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_err_alert(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
closed_tls_task,
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to send us a TLS record with a bad MAC
|
||||
client_tls_conn
|
||||
.write_all(&pad("send_alert".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
client_tls_conn.close().await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
closed_tls_task.await.unwrap().err().unwrap().to_string(),
|
||||
"received fatal alert: BadRecordMac"
|
||||
);
|
||||
}
|
||||
|
||||
// Expect an error when trying to write data to a connection which server closed
|
||||
// abruptly
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_err_write_after_close(set_up_tls: impl Future<Output = TlsFixture>) {
|
||||
let TlsFixture {
|
||||
mut client_tls_conn,
|
||||
..
|
||||
} = set_up_tls.await;
|
||||
|
||||
// instruct the server to close the socket
|
||||
client_tls_conn
|
||||
.write_all(&pad("close_socket".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
client_tls_conn.flush().await.unwrap();
|
||||
|
||||
// give enough time to close the socket
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
|
||||
// try to send some more data
|
||||
let res = client_tls_conn
|
||||
.write_all(&pad("more data".to_string()))
|
||||
.await;
|
||||
|
||||
assert_eq!(res.err().unwrap().kind(), std::io::ErrorKind::BrokenPipe);
|
||||
}
|
||||
|
||||
// Converts a string into a slice zero-padded to APP_RECORD_LENGTH
|
||||
fn pad(s: String) -> Vec<u8> {
|
||||
assert!(s.len() <= APP_RECORD_LENGTH);
|
||||
let mut buf = vec![0u8; APP_RECORD_LENGTH];
|
||||
buf[..s.len()].copy_from_slice(s.as_bytes());
|
||||
buf
|
||||
}
|
||||
@@ -457,6 +457,9 @@ impl ConnectionCommon {
|
||||
return Err(Error::CorruptMessage);
|
||||
}
|
||||
|
||||
// Process outgoing plaintext buffer and encrypt messages.
|
||||
self.flush_plaintext().await?;
|
||||
|
||||
// Process new messages.
|
||||
while let Some(msg) = self.message_deframer.frames.pop_front() {
|
||||
// If we're not decrypting yet, we process it immediately. Otherwise it will be
|
||||
@@ -508,25 +511,22 @@ impl ConnectionCommon {
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
/// Write buffer into connection.
|
||||
pub async fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
|
||||
if let Ok(st) = &mut self.state {
|
||||
st.perhaps_write_key_update(&mut self.common_state).await;
|
||||
/// Writes plaintext `buf` into an internal buffer. May not fully process the
|
||||
/// whole buffer and returns the processed length.
|
||||
pub fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
|
||||
if buf.is_empty() {
|
||||
// Don't send empty fragments.
|
||||
return Ok(0);
|
||||
}
|
||||
self.common_state.send_some_plaintext(buf).await
|
||||
|
||||
let len = self.sendable_plaintext.append_limited_copy(buf);
|
||||
Ok(len)
|
||||
}
|
||||
|
||||
/// Write entire buffer into connection.
|
||||
pub async fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
|
||||
let mut pos = 0;
|
||||
while pos < buf.len() {
|
||||
pos += self.write_plaintext(&buf[pos..]).await?;
|
||||
}
|
||||
self.backend.flush().await?;
|
||||
while let Some(msg) = self.backend.next_outgoing().await? {
|
||||
self.queue_tls_message(msg);
|
||||
}
|
||||
Ok(pos)
|
||||
/// Writes the entire plaintext `buf` into an internal buffer.
|
||||
pub fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<(), Error> {
|
||||
self.sendable_plaintext.append(buf.to_vec());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read TLS content from `rd`. This method does internal
|
||||
@@ -690,6 +690,11 @@ impl CommonState {
|
||||
self.received_plaintext.is_empty()
|
||||
}
|
||||
|
||||
/// Returns true if the buffer for sendable plaintext is full.
|
||||
pub fn sendable_plaintext_is_full(&self) -> bool {
|
||||
self.sendable_plaintext.is_full()
|
||||
}
|
||||
|
||||
/// Returns true if the connection is currently performing the TLS
|
||||
/// handshake.
|
||||
///
|
||||
@@ -782,15 +787,6 @@ impl CommonState {
|
||||
}
|
||||
}
|
||||
|
||||
/// Send plaintext application data, fragmenting and
|
||||
/// encrypting it as it goes out.
|
||||
///
|
||||
/// If internal buffers are too small, this function will not accept
|
||||
/// all the data.
|
||||
pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> Result<usize, Error> {
|
||||
self.send_plain(data, Limit::Yes).await
|
||||
}
|
||||
|
||||
// Changing the keys must not span any fragmented handshake
|
||||
// messages. Otherwise the defragmented messages will have
|
||||
// been protected with two different record layer protections,
|
||||
@@ -931,32 +927,6 @@ impl CommonState {
|
||||
self.sendable_tls.write_to_async(wr).await
|
||||
}
|
||||
|
||||
/// Encrypt and send some plaintext `data`. `limit` controls
|
||||
/// whether the per-connection buffer limits apply.
|
||||
///
|
||||
/// Returns the number of bytes written from `data`: this might
|
||||
/// be less than `data.len()` if buffer limits were exceeded.
|
||||
async fn send_plain(&mut self, data: &[u8], limit: Limit) -> Result<usize, Error> {
|
||||
if !self.may_send_application_data {
|
||||
// If we haven't completed handshaking, buffer
|
||||
// plaintext to send once we do.
|
||||
let len = match limit {
|
||||
Limit::Yes => self.sendable_plaintext.append_limited_copy(data),
|
||||
Limit::No => self.sendable_plaintext.append(data.to_vec()),
|
||||
};
|
||||
return Ok(len);
|
||||
}
|
||||
|
||||
debug_assert!(self.record_layer.is_encrypting());
|
||||
|
||||
if data.is_empty() {
|
||||
// Don't send empty fragments.
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
self.send_appdata_encrypt(data, limit).await
|
||||
}
|
||||
|
||||
pub(crate) async fn start_outgoing_traffic(&mut self) -> Result<(), Error> {
|
||||
self.may_send_application_data = true;
|
||||
self.flush_plaintext().await
|
||||
@@ -1012,15 +982,14 @@ impl CommonState {
|
||||
self.sendable_tls.set_limit(limit);
|
||||
}
|
||||
|
||||
/// Send any buffered plaintext. Plaintext is buffered if
|
||||
/// written during handshake.
|
||||
async fn flush_plaintext(&mut self) -> Result<(), Error> {
|
||||
/// Send and encrypt any buffered plaintext. Does nothing during handshake.
|
||||
pub async fn flush_plaintext(&mut self) -> Result<(), Error> {
|
||||
if !self.may_send_application_data {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
while let Some(buf) = self.sendable_plaintext.pop() {
|
||||
self.send_plain(&buf, Limit::No).await?;
|
||||
self.send_appdata_encrypt(&buf, Limit::No).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -35,6 +35,15 @@ impl ChunkVecBuffer {
|
||||
self.chunks.is_empty()
|
||||
}
|
||||
|
||||
/// If the buffer has reached limit.
|
||||
pub(crate) fn is_full(&self) -> bool {
|
||||
if let Some(limit) = self.limit {
|
||||
self.len() >= limit
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// How many bytes we're storing
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
let mut len = 0;
|
||||
|
||||
@@ -247,7 +247,8 @@ async fn servered_client_data_sent() {
|
||||
let (mut client, mut server) =
|
||||
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
|
||||
|
||||
assert_eq!(5, client.write_plaintext(b"hello").await.unwrap());
|
||||
assert_eq!(5, client.write_plaintext(b"hello").unwrap());
|
||||
client.flush_plaintext().await.unwrap();
|
||||
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
send(&mut client, &mut server);
|
||||
@@ -286,7 +287,7 @@ async fn servered_both_data_sent() {
|
||||
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
|
||||
|
||||
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
|
||||
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
|
||||
@@ -432,7 +433,7 @@ async fn server_close_notify() {
|
||||
|
||||
// check that alerts don't overtake appdata
|
||||
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
|
||||
server.send_close_notify();
|
||||
|
||||
receive(&mut server, &mut client);
|
||||
@@ -460,7 +461,8 @@ async fn client_close_notify() {
|
||||
|
||||
// check that alerts don't overtake appdata
|
||||
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
|
||||
client.flush_plaintext().await.unwrap();
|
||||
client.send_close_notify().await.unwrap();
|
||||
|
||||
send(&mut client, &mut server);
|
||||
@@ -487,7 +489,7 @@ async fn server_closes_uncleanly() {
|
||||
|
||||
// check that unclean EOF reporting does not overtake appdata
|
||||
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
|
||||
|
||||
receive(&mut server, &mut client);
|
||||
transfer_eof(&mut client);
|
||||
@@ -518,7 +520,7 @@ async fn client_closes_uncleanly() {
|
||||
|
||||
// check that unclean EOF reporting does not overtake appdata
|
||||
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
|
||||
client.process_new_packets().await.unwrap();
|
||||
|
||||
send(&mut client, &mut server);
|
||||
@@ -900,20 +902,9 @@ async fn client_respects_buffer_limit_pre_handshake() {
|
||||
|
||||
client.set_buffer_limit(Some(32));
|
||||
|
||||
assert_eq!(
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap(),
|
||||
20
|
||||
);
|
||||
assert_eq!(
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap(),
|
||||
12
|
||||
);
|
||||
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
|
||||
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 12);
|
||||
client.flush_plaintext().await.unwrap();
|
||||
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
send(&mut client, &mut server);
|
||||
@@ -953,20 +944,9 @@ async fn client_respects_buffer_limit_post_handshake() {
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
client.set_buffer_limit(Some(48));
|
||||
|
||||
assert_eq!(
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap(),
|
||||
20
|
||||
);
|
||||
assert_eq!(
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap(),
|
||||
6
|
||||
);
|
||||
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
|
||||
assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 6);
|
||||
client.flush_plaintext().await.unwrap();
|
||||
|
||||
send(&mut client, &mut server);
|
||||
server.process_new_packets().unwrap();
|
||||
@@ -1211,14 +1191,8 @@ async fn client_complete_io_for_write() {
|
||||
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap();
|
||||
client
|
||||
.write_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap();
|
||||
client.write_plaintext(b"01234567890123456789").unwrap();
|
||||
client.write_plaintext(b"01234567890123456789").unwrap();
|
||||
{
|
||||
let mut pipe = ServerSession::new(&mut server);
|
||||
let (rdlen, wrlen) = client
|
||||
@@ -1350,7 +1324,8 @@ async fn server_stream_read() {
|
||||
for kt in ALL_KEY_TYPES.iter() {
|
||||
let (mut client, mut server) = make_pair(*kt).await;
|
||||
|
||||
client.write_all_plaintext(b"world").await.unwrap();
|
||||
client.write_all_plaintext(b"world").unwrap();
|
||||
client.process_new_packets().await.unwrap();
|
||||
|
||||
{
|
||||
let mut pipe = ClientSession::new(&mut client);
|
||||
@@ -1366,7 +1341,8 @@ async fn server_streamowned_read() {
|
||||
for kt in ALL_KEY_TYPES.iter() {
|
||||
let (mut client, server) = make_pair(*kt).await;
|
||||
|
||||
client.write_all_plaintext(b"world").await.unwrap();
|
||||
client.write_all_plaintext(b"world").unwrap();
|
||||
client.process_new_packets().await.unwrap();
|
||||
|
||||
{
|
||||
let pipe = ClientSession::new(&mut client);
|
||||
@@ -1385,7 +1361,9 @@ async fn server_streamowned_read() {
|
||||
// errkind: io::ErrorKind::ConnectionAborted,
|
||||
// after: 0,
|
||||
// };
|
||||
// client.write_all_plaintext(b"hello").await.unwrap();
|
||||
// client.write_all_plaintext(b"hello").unwrap();
|
||||
// client.process_new_packets().await.unwrap();
|
||||
//
|
||||
// let mut client_stream = Stream::new(&mut client, &mut pipe);
|
||||
// let rc = client_stream.write(b"world");
|
||||
// assert!(rc.is_err());
|
||||
@@ -1402,7 +1380,9 @@ async fn server_streamowned_read() {
|
||||
// errkind: io::ErrorKind::ConnectionAborted,
|
||||
// after: 1,
|
||||
// };
|
||||
// client.write_all_plaintext(b"hello").await.unwrap();
|
||||
// client.write_all_plaintext(b"hello").unwrap();
|
||||
// client.process_new_packets().await.unwrap();
|
||||
//
|
||||
// let mut client_stream = Stream::new(&mut client, &mut pipe);
|
||||
// let rc = client_stream.write(b"world");
|
||||
// assert_eq!(format!("{:?}", rc), "Ok(5)");
|
||||
@@ -1900,14 +1880,9 @@ async fn servered_write_for_client_appdata() {
|
||||
let (mut client, mut server) = make_pair(KeyType::Rsa).await;
|
||||
do_handshake(&mut client, &mut server).await;
|
||||
|
||||
client
|
||||
.write_all_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap();
|
||||
client
|
||||
.write_all_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap();
|
||||
client.write_all_plaintext(b"01234567890123456789").unwrap();
|
||||
client.write_all_plaintext(b"01234567890123456789").unwrap();
|
||||
client.process_new_packets().await.unwrap();
|
||||
{
|
||||
let mut pipe = ServerSession::new(&mut server);
|
||||
let wrlen = client.write_tls(&mut pipe).unwrap();
|
||||
@@ -2019,11 +1994,10 @@ async fn servered_write_for_server_handshake_no_half_rtt_by_default() {
|
||||
async fn servered_write_for_client_handshake() {
|
||||
let (mut client, mut server) = make_pair(KeyType::Rsa).await;
|
||||
|
||||
client
|
||||
.write_all_plaintext(b"01234567890123456789")
|
||||
.await
|
||||
.unwrap();
|
||||
client.write_all_plaintext(b"0123456789").await.unwrap();
|
||||
client.write_all_plaintext(b"01234567890123456789").unwrap();
|
||||
client.write_all_plaintext(b"0123456789").unwrap();
|
||||
client.process_new_packets().await.unwrap();
|
||||
|
||||
{
|
||||
let mut pipe = ServerSession::new(&mut server);
|
||||
let wrlen = client.write_tls(&mut pipe).unwrap();
|
||||
|
||||
@@ -21,11 +21,11 @@ tlsn-attestation = { workspace = true }
|
||||
tlsn-core = { workspace = true }
|
||||
tlsn-deap = { workspace = true }
|
||||
tlsn-tls-client = { workspace = true }
|
||||
tlsn-tls-client-async = { workspace = true }
|
||||
tlsn-tls-core = { workspace = true }
|
||||
tlsn-mpc-tls = { workspace = true }
|
||||
tlsn-cipher = { workspace = true }
|
||||
|
||||
futures-plex = { workspace = true }
|
||||
serio = { workspace = true, features = ["compat"] }
|
||||
uid-mux = { workspace = true, features = ["serio"] }
|
||||
web-spawn = { workspace = true, optional = true }
|
||||
@@ -57,6 +57,7 @@ serde = { workspace = true, features = ["derive"] }
|
||||
ghash = { workspace = true }
|
||||
semver = { workspace = true, features = ["serde"] }
|
||||
once_cell = { workspace = true }
|
||||
pin-project-lite = { workspace = true }
|
||||
rangeset = { workspace = true }
|
||||
webpki-roots = { workspace = true }
|
||||
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
//! Execution context.
|
||||
|
||||
use mpz_common::context::Multithread;
|
||||
|
||||
use crate::mux::MuxControl;
|
||||
|
||||
/// Maximum concurrency for multi-threaded context.
|
||||
pub(crate) const MAX_CONCURRENCY: usize = 8;
|
||||
|
||||
/// Builds a multi-threaded context with the given muxer.
|
||||
pub(crate) fn build_mt_context(mux: MuxControl) -> Multithread {
|
||||
let builder = Multithread::builder().mux(mux).concurrency(MAX_CONCURRENCY);
|
||||
|
||||
#[cfg(all(feature = "web", target_arch = "wasm32"))]
|
||||
let builder = builder.spawn_handler(|f| {
|
||||
let _ = web_spawn::spawn(f);
|
||||
Ok(())
|
||||
});
|
||||
|
||||
builder.build().unwrap()
|
||||
}
|
||||
@@ -4,7 +4,6 @@
|
||||
#![deny(clippy::all)]
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
pub(crate) mod context;
|
||||
pub(crate) mod ghash;
|
||||
pub(crate) mod map;
|
||||
pub(crate) mod mpz;
|
||||
@@ -13,6 +12,7 @@ pub(crate) mod mux;
|
||||
pub mod prover;
|
||||
pub(crate) mod tag;
|
||||
pub(crate) mod transcript_internal;
|
||||
pub(crate) mod utils;
|
||||
pub mod verifier;
|
||||
|
||||
pub use tlsn_attestation as attestation;
|
||||
@@ -27,6 +27,8 @@ pub(crate) static VERSION: LazyLock<Version> = LazyLock::new(|| {
|
||||
Version::parse(env!("CARGO_PKG_VERSION")).expect("cargo pkg version should be a valid semver")
|
||||
});
|
||||
|
||||
const BUF_CAP: usize = 16 * 1024;
|
||||
|
||||
/// The party's role in the TLSN protocol.
|
||||
///
|
||||
/// A Notary is classified as a Verifier.
|
||||
|
||||
@@ -21,6 +21,20 @@ impl<T> RangeMap<T>
|
||||
where
|
||||
T: Item,
|
||||
{
|
||||
pub(crate) fn new(map: Vec<(usize, T)>) -> Self {
|
||||
let mut pos = 0;
|
||||
for (idx, item) in &map {
|
||||
assert!(
|
||||
*idx >= pos,
|
||||
"items must be sorted by index and non-overlapping"
|
||||
);
|
||||
|
||||
pos = *idx + item.length();
|
||||
}
|
||||
|
||||
Self { map }
|
||||
}
|
||||
|
||||
/// Returns `true` if the map is empty.
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
self.map.is_empty()
|
||||
@@ -33,6 +47,11 @@ where
|
||||
.map(|(idx, item)| *idx..*idx + item.length())
|
||||
}
|
||||
|
||||
/// Returns the length of the map.
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
self.map.iter().map(|(_, item)| item.length()).sum()
|
||||
}
|
||||
|
||||
pub(crate) fn iter(&self) -> impl Iterator<Item = (Range<usize>, &T)> {
|
||||
self.map
|
||||
.iter()
|
||||
|
||||
@@ -6,6 +6,11 @@ use mpz_core::Block;
|
||||
#[cfg(not(tlsn_insecure))]
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
|
||||
use mpz_garble_core::Delta;
|
||||
use mpz_memory_core::{
|
||||
Vector,
|
||||
binary::U8,
|
||||
correlated::{Key, Mac},
|
||||
};
|
||||
#[cfg(not(tlsn_insecure))]
|
||||
use mpz_ot::cot::{DerandCOTReceiver, DerandCOTSender};
|
||||
use mpz_ot::{
|
||||
@@ -19,6 +24,8 @@ use tlsn_core::config::tls_commit::mpc::{MpcTlsConfig, NetworkSetting};
|
||||
use tlsn_deap::Deap;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::transcript_internal::commit::encoding::{KeyStore, MacStore};
|
||||
|
||||
#[cfg(not(tlsn_insecure))]
|
||||
pub(crate) type ProverMpc =
|
||||
Garbler<DerandCOTSender<SharedRCOTSender<kos::Sender<co::Receiver>, Block>>>;
|
||||
@@ -186,3 +193,41 @@ pub(crate) fn translate_keys<Mpc, Zk>(keys: &mut SessionKeys, vm: &Deap<Mpc, Zk>
|
||||
.translate(keys.server_write_mac_key)
|
||||
.expect("VM memory should be consistent");
|
||||
}
|
||||
|
||||
impl<T> KeyStore for Verifier<T> {
|
||||
fn delta(&self) -> &Delta {
|
||||
self.delta()
|
||||
}
|
||||
|
||||
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]> {
|
||||
self.get_keys(data).ok()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> MacStore for Prover<T> {
|
||||
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]> {
|
||||
self.get_macs(data).ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(tlsn_insecure)]
|
||||
mod insecure {
|
||||
use super::*;
|
||||
use mpz_ideal_vm::IdealVm;
|
||||
|
||||
impl KeyStore for IdealVm {
|
||||
fn delta(&self) -> &Delta {
|
||||
unimplemented!("encodings not supported in insecure mode")
|
||||
}
|
||||
|
||||
fn get_keys(&self, _data: Vector<U8>) -> Option<&[Key]> {
|
||||
unimplemented!("encodings not supported in insecure mode")
|
||||
}
|
||||
}
|
||||
|
||||
impl MacStore for IdealVm {
|
||||
fn get_macs(&self, _data: Vector<U8>) -> Option<&[Mac]> {
|
||||
unimplemented!("encodings not supported in insecure mode")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,31 +1,38 @@
|
||||
//! Prover.
|
||||
|
||||
mod client;
|
||||
mod conn;
|
||||
mod control;
|
||||
mod error;
|
||||
mod future;
|
||||
mod prove;
|
||||
pub mod state;
|
||||
|
||||
pub use conn::TlsConnection;
|
||||
pub use control::ProverControl;
|
||||
pub use error::ProverError;
|
||||
pub use future::ProverFuture;
|
||||
pub use tlsn_core::ProverOutput;
|
||||
|
||||
use crate::{
|
||||
Role,
|
||||
context::build_mt_context,
|
||||
BUF_CAP, Role,
|
||||
mpz::{ProverDeps, build_prover_deps, translate_keys},
|
||||
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
|
||||
mux::attach_mux,
|
||||
tag::verify_tags,
|
||||
prover::{
|
||||
client::{MpcTlsClient, TlsOutput},
|
||||
state::ConnectedProj,
|
||||
},
|
||||
utils::{CopyIo, await_with_copy_io, build_mt_context},
|
||||
};
|
||||
|
||||
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
|
||||
use mpc_tls::LeaderCtrl;
|
||||
use mpz_vm_core::prelude::*;
|
||||
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, FutureExt, TryFutureExt, ready};
|
||||
use rustls_pki_types::CertificateDer;
|
||||
use serio::{SinkExt, stream::IoStreamExt};
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tls_client::{ClientConnection, ServerName as TlsServerName};
|
||||
use tls_client_async::{TlsConnection, bind_client};
|
||||
use tlsn_core::{
|
||||
config::{
|
||||
prove::ProveConfig,
|
||||
@@ -36,10 +43,9 @@ use tlsn_core::{
|
||||
connection::{HandshakeData, ServerName},
|
||||
transcript::{TlsTranscript, Transcript},
|
||||
};
|
||||
use tracing::{Span, debug, info_span, instrument};
|
||||
use webpki::anchor_from_trusted_cert;
|
||||
|
||||
use tracing::{Instrument, Span, debug, info, info_span, instrument};
|
||||
|
||||
/// A prover instance.
|
||||
#[derive(Debug)]
|
||||
pub struct Prover<T: state::ProverState = state::Initialized> {
|
||||
@@ -71,14 +77,27 @@ impl Prover<state::Initialized> {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - The TLS commitment configuration.
|
||||
/// * `socket` - The socket to the TLS verifier.
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
/// * `verifier_io` - The IO to the TLS verifier.
|
||||
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin>(
|
||||
self,
|
||||
config: TlsCommitConfig,
|
||||
socket: S,
|
||||
verifier_io: S,
|
||||
) -> Result<Prover<state::CommitAccepted>, ProverError> {
|
||||
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
|
||||
let (duplex_a, mut duplex_b) = futures_plex::duplex(BUF_CAP);
|
||||
let fut = Box::pin(self.commit_inner(config, duplex_a).fuse());
|
||||
let mut prover = await_with_copy_io(fut, verifier_io, &mut duplex_b).await?;
|
||||
|
||||
prover.state.verifier_io = Some(duplex_b);
|
||||
Ok(prover)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn commit_inner<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
self,
|
||||
config: TlsCommitConfig,
|
||||
verifier_io: S,
|
||||
) -> Result<Prover<state::CommitAccepted>, ProverError> {
|
||||
let (mut mux_fut, mux_ctrl) = attach_mux(verifier_io, Role::Prover);
|
||||
let mut mt = build_mt_context(mux_ctrl.clone());
|
||||
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
|
||||
|
||||
@@ -122,6 +141,7 @@ impl Prover<state::Initialized> {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::CommitAccepted {
|
||||
verifier_io: None,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
mpc_tls,
|
||||
@@ -133,31 +153,29 @@ impl Prover<state::Initialized> {
|
||||
}
|
||||
|
||||
impl Prover<state::CommitAccepted> {
|
||||
/// Connects to the server using the provided socket.
|
||||
/// Sets up the prover with the client configuration.
|
||||
///
|
||||
/// Returns a handle to the TLS connection, a future which returns the
|
||||
/// prover once the connection is closed and the TLS transcript is
|
||||
/// committed.
|
||||
/// Returns a set up prover, and a [`TlsConnection`] which can be used to
|
||||
/// read and write bytes from/to the server.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - The TLS client configuration.
|
||||
/// * `socket` - The socket to the server.
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
pub fn setup(
|
||||
self,
|
||||
config: TlsClientConfig,
|
||||
socket: S,
|
||||
) -> Result<(TlsConnection, ProverFuture), ProverError> {
|
||||
) -> Result<(TlsConnection, Prover<state::Setup>), ProverError> {
|
||||
let state::CommitAccepted {
|
||||
verifier_io,
|
||||
mux_ctrl,
|
||||
mut mux_fut,
|
||||
mux_fut,
|
||||
mpc_tls,
|
||||
keys,
|
||||
vm,
|
||||
..
|
||||
} = self.state;
|
||||
|
||||
let decrypt = mpc_tls.is_decrypting();
|
||||
let (mpc_ctrl, mpc_fut) = mpc_tls.run();
|
||||
|
||||
let ServerName::Dns(server_name) = config.server_name();
|
||||
@@ -202,95 +220,296 @@ impl Prover<state::CommitAccepted> {
|
||||
)
|
||||
.map_err(ProverError::config)?;
|
||||
|
||||
let (conn, conn_fut) = bind_client(socket, client);
|
||||
let span = self.span.clone();
|
||||
|
||||
let fut = Box::pin({
|
||||
let span = self.span.clone();
|
||||
let mpc_ctrl = mpc_ctrl.clone();
|
||||
async move {
|
||||
let conn_fut = async {
|
||||
mux_fut
|
||||
.poll_with(conn_fut.map_err(ProverError::from))
|
||||
.await?;
|
||||
let mpc_tls = MpcTlsClient::new(
|
||||
Box::new(mpc_fut.map_err(ProverError::from)),
|
||||
keys,
|
||||
vm,
|
||||
span,
|
||||
mpc_ctrl,
|
||||
client,
|
||||
decrypt,
|
||||
);
|
||||
|
||||
mpc_ctrl.stop().await?;
|
||||
|
||||
Ok::<_, ProverError>(())
|
||||
};
|
||||
|
||||
info!("starting MPC-TLS");
|
||||
|
||||
let (_, (mut ctx, tls_transcript)) = futures::try_join!(
|
||||
conn_fut,
|
||||
mpc_fut.in_current_span().map_err(ProverError::from)
|
||||
)?;
|
||||
|
||||
info!("finished MPC-TLS");
|
||||
|
||||
{
|
||||
let mut vm = vm.try_lock().expect("VM should not be locked");
|
||||
|
||||
debug!("finalizing mpc");
|
||||
|
||||
// Finalize DEAP.
|
||||
mux_fut
|
||||
.poll_with(vm.finalize(&mut ctx))
|
||||
.await
|
||||
.map_err(ProverError::mpc)?;
|
||||
|
||||
debug!("mpc finalized");
|
||||
}
|
||||
|
||||
// Pull out ZK VM.
|
||||
let (_, mut vm) = Arc::into_inner(vm)
|
||||
.expect("vm should have only 1 reference")
|
||||
.into_inner()
|
||||
.into_inner();
|
||||
|
||||
// Prove tag verification of received records.
|
||||
// The prover drops the proof output.
|
||||
let _ = verify_tags(
|
||||
&mut vm,
|
||||
(keys.server_write_key, keys.server_write_iv),
|
||||
keys.server_write_mac_key,
|
||||
*tls_transcript.version(),
|
||||
tls_transcript.recv().to_vec(),
|
||||
)
|
||||
.map_err(ProverError::zk)?;
|
||||
|
||||
mux_fut
|
||||
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
|
||||
.await?;
|
||||
|
||||
let transcript = tls_transcript
|
||||
.to_transcript()
|
||||
.expect("transcript is complete");
|
||||
|
||||
Ok(Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Committed {
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
ctx,
|
||||
vm,
|
||||
server_name: config.server_name().clone(),
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
},
|
||||
})
|
||||
}
|
||||
.instrument(span)
|
||||
});
|
||||
|
||||
Ok((
|
||||
conn,
|
||||
ProverFuture {
|
||||
fut,
|
||||
ctrl: ProverControl { mpc_ctrl },
|
||||
let (duplex_a, duplex_b) = futures_plex::duplex(BUF_CAP);
|
||||
let prover = Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Setup {
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
server_name: config.server_name().clone(),
|
||||
tls_client: Box::new(mpc_tls),
|
||||
client_io: duplex_a,
|
||||
verifier_io,
|
||||
},
|
||||
))
|
||||
};
|
||||
|
||||
let conn = TlsConnection::new(duplex_b);
|
||||
Ok((conn, prover))
|
||||
}
|
||||
}
|
||||
|
||||
impl Prover<state::Setup> {
|
||||
/// Returns a handle to control the prover.
|
||||
pub fn handle(&self) -> ProverControl {
|
||||
let handle = self.state.tls_client.handle();
|
||||
ProverControl { handle }
|
||||
}
|
||||
|
||||
/// Attaches IO to the prover.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `server_io` - The IO to the server.
|
||||
/// * `verifier_io` - The IO to the TLS verifier.
|
||||
pub fn connect<S, T>(self, server_io: S, verifier_io: T) -> Prover<state::Connected<S, T>>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
{
|
||||
let (client_to_server, server_to_client) = futures_plex::duplex(BUF_CAP);
|
||||
|
||||
Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Connected {
|
||||
verifier_io: self.state.verifier_io,
|
||||
mux_ctrl: self.state.mux_ctrl,
|
||||
mux_fut: self.state.mux_fut,
|
||||
server_name: self.state.server_name,
|
||||
tls_client: self.state.tls_client,
|
||||
client_io: self.state.client_io,
|
||||
output: None,
|
||||
server_socket: server_io,
|
||||
verifier_socket: verifier_io,
|
||||
tls_client_to_server_buf: client_to_server,
|
||||
server_to_tls_client_buf: server_to_client,
|
||||
client_closed: false,
|
||||
server_closed: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// This is a convenience method which attaches IO, runs the prover and
|
||||
/// returns a committed prover together with the IO.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `server_io` - The IO to the server.
|
||||
/// * `verifier_io` - The IO to the TLS verifier.
|
||||
pub async fn run<S, T>(
|
||||
self,
|
||||
mut server_io: S,
|
||||
mut verifier_io: T,
|
||||
) -> Result<(Prover<state::Committed>, S, T), ProverError>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
{
|
||||
let mut prover = self.connect(&mut server_io, &mut verifier_io);
|
||||
(&mut prover).await?;
|
||||
|
||||
let prover = prover.finish()?;
|
||||
Ok((prover, server_io, verifier_io))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T> Future for Prover<state::Connected<S, T>>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
{
|
||||
type Output = Result<(), ProverError>;
|
||||
|
||||
fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut state = Pin::new(&mut self.state).project();
|
||||
|
||||
loop {
|
||||
let mut progress = false;
|
||||
|
||||
if state.output.is_none()
|
||||
&& let Poll::Ready(output) = state.tls_client.poll(cx)?
|
||||
{
|
||||
*state.output = Some(output);
|
||||
}
|
||||
|
||||
progress |= Self::io_client_conn(&mut state, cx)?;
|
||||
progress |= Self::io_client_server(&mut state, cx)?;
|
||||
progress |= Self::io_client_verifier(&mut state, cx)?;
|
||||
|
||||
_ = state.mux_fut.poll_unpin(cx)?;
|
||||
|
||||
if *state.server_closed && state.output.is_some() {
|
||||
ready!(state.client_io.poll_close(cx))?;
|
||||
ready!(state.server_socket.poll_close(cx))?;
|
||||
|
||||
return Poll::Ready(Ok(()));
|
||||
} else if !progress {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T> Prover<state::Connected<S, T>>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
{
|
||||
fn io_client_conn(
|
||||
state: &mut ConnectedProj<S, T>,
|
||||
cx: &mut Context,
|
||||
) -> Result<bool, ProverError> {
|
||||
let mut progress = false;
|
||||
|
||||
// tls_conn -> tls_client
|
||||
if state.tls_client.wants_write()
|
||||
&& let Poll::Ready(mut simplex) = state.client_io.as_mut().poll_lock_read(cx)
|
||||
&& let Poll::Ready(buf) = simplex.poll_get(cx)?
|
||||
{
|
||||
if !buf.is_empty() {
|
||||
let write = state.tls_client.write(buf)?;
|
||||
if write > 0 {
|
||||
progress = true;
|
||||
simplex.advance(write);
|
||||
}
|
||||
} else if !*state.client_closed && !*state.server_closed {
|
||||
progress = true;
|
||||
*state.client_closed = true;
|
||||
state.tls_client.client_close()?;
|
||||
}
|
||||
}
|
||||
|
||||
// tls_client -> tls_conn
|
||||
if state.tls_client.wants_read()
|
||||
&& let Poll::Ready(mut simplex) = state.client_io.as_mut().poll_lock_write(cx)
|
||||
&& let Poll::Ready(buf) = simplex.poll_mut(cx)?
|
||||
&& let read = state.tls_client.read(buf)?
|
||||
&& read > 0
|
||||
{
|
||||
progress = true;
|
||||
simplex.advance_mut(read);
|
||||
}
|
||||
Ok(progress)
|
||||
}
|
||||
|
||||
fn io_client_server(
|
||||
state: &mut ConnectedProj<S, T>,
|
||||
cx: &mut Context,
|
||||
) -> Result<bool, ProverError> {
|
||||
let mut progress = false;
|
||||
|
||||
// server_socket -> buf
|
||||
if let Poll::Ready(write) = state
|
||||
.server_to_tls_client_buf
|
||||
.poll_write_from(cx, state.server_socket.as_mut())?
|
||||
{
|
||||
if write > 0 {
|
||||
progress = true;
|
||||
} else if !*state.server_closed {
|
||||
progress = true;
|
||||
*state.server_closed = true;
|
||||
state.tls_client.server_close()?;
|
||||
}
|
||||
}
|
||||
|
||||
// buf -> tls_client
|
||||
if state.tls_client.wants_read_tls()
|
||||
&& let Poll::Ready(mut simplex) =
|
||||
state.tls_client_to_server_buf.as_mut().poll_lock_read(cx)
|
||||
&& let Poll::Ready(buf) = simplex.poll_get(cx)?
|
||||
&& let read = state.tls_client.read_tls(buf)?
|
||||
&& read > 0
|
||||
{
|
||||
progress = true;
|
||||
simplex.advance(read);
|
||||
}
|
||||
|
||||
// tls_client -> buf
|
||||
if state.tls_client.wants_write_tls()
|
||||
&& let Poll::Ready(mut simplex) =
|
||||
state.tls_client_to_server_buf.as_mut().poll_lock_write(cx)
|
||||
&& let Poll::Ready(buf) = simplex.poll_mut(cx)?
|
||||
&& let write = state.tls_client.write_tls(buf)?
|
||||
&& write > 0
|
||||
{
|
||||
progress = true;
|
||||
simplex.advance_mut(write);
|
||||
}
|
||||
|
||||
// buf -> server_socket
|
||||
if let Poll::Ready(read) = state
|
||||
.server_to_tls_client_buf
|
||||
.poll_read_to(cx, state.server_socket.as_mut())?
|
||||
&& read > 0
|
||||
{
|
||||
progress = true;
|
||||
}
|
||||
|
||||
Ok(progress)
|
||||
}
|
||||
|
||||
fn io_client_verifier(
|
||||
state: &mut ConnectedProj<S, T>,
|
||||
cx: &mut Context,
|
||||
) -> Result<bool, ProverError> {
|
||||
let mut progress = false;
|
||||
|
||||
let verifier_io = Pin::new(
|
||||
(*state.verifier_io)
|
||||
.as_mut()
|
||||
.expect("verifier io should be available"),
|
||||
);
|
||||
|
||||
// mux -> verifier_socket
|
||||
if let Poll::Ready(read) = verifier_io.poll_read_to(cx, state.verifier_socket.as_mut())?
|
||||
&& read > 0
|
||||
{
|
||||
progress = true;
|
||||
}
|
||||
|
||||
// verifier_socket -> mux
|
||||
if let Poll::Ready(write) =
|
||||
verifier_io.poll_write_from(cx, state.verifier_socket.as_mut())?
|
||||
&& write > 0
|
||||
{
|
||||
progress = true;
|
||||
}
|
||||
|
||||
Ok(progress)
|
||||
}
|
||||
|
||||
/// Returns a committed prover after the TLS session has completed.
|
||||
pub fn finish(self) -> Result<Prover<state::Committed>, ProverError> {
|
||||
let TlsOutput {
|
||||
ctx,
|
||||
vm,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
} = self.state.output.ok_or(ProverError::state(
|
||||
"prover has not yet closed the connection",
|
||||
))?;
|
||||
|
||||
let prover = Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Committed {
|
||||
verifier_io: self.state.verifier_io,
|
||||
mux_ctrl: self.state.mux_ctrl,
|
||||
mux_fut: self.state.mux_fut,
|
||||
ctx,
|
||||
vm,
|
||||
server_name: self.state.server_name,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
},
|
||||
};
|
||||
|
||||
Ok(prover)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,8 +529,30 @@ impl Prover<state::Committed> {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - The disclosure configuration.
|
||||
/// * `verifier_io` - The IO to the TLS verifier.
|
||||
pub async fn prove<S>(
|
||||
&mut self,
|
||||
config: &ProveConfig,
|
||||
verifier_io: S,
|
||||
) -> Result<ProverOutput, ProverError>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
{
|
||||
let mut duplex = self
|
||||
.state
|
||||
.verifier_io
|
||||
.take()
|
||||
.expect("duplex should be available");
|
||||
|
||||
let fut = Box::pin(self.prove_inner(config).fuse());
|
||||
let output = await_with_copy_io(fut, verifier_io, &mut duplex).await?;
|
||||
|
||||
self.state.verifier_io = Some(duplex);
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "info", skip_all, err)]
|
||||
pub async fn prove(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> {
|
||||
async fn prove_inner(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> {
|
||||
let state::Committed {
|
||||
mux_fut,
|
||||
ctx,
|
||||
@@ -364,44 +605,31 @@ impl Prover<state::Committed> {
|
||||
}
|
||||
|
||||
/// Closes the connection with the verifier.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `verifier_io` - The IO to the TLS verifier.
|
||||
#[instrument(parent = &self.span, level = "info", skip_all, err)]
|
||||
pub async fn close(self) -> Result<(), ProverError> {
|
||||
pub async fn close<S>(mut self, mut verifier_io: S) -> Result<(), ProverError>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
{
|
||||
let state::Committed {
|
||||
mux_ctrl, mux_fut, ..
|
||||
} = self.state;
|
||||
|
||||
// Wait for the verifier to correctly close the connection.
|
||||
if !mux_fut.is_complete() {
|
||||
mux_ctrl.close();
|
||||
mux_fut.await?;
|
||||
}
|
||||
let mut duplex = self
|
||||
.state
|
||||
.verifier_io
|
||||
.take()
|
||||
.expect("duplex should be available");
|
||||
|
||||
mux_ctrl.close();
|
||||
let copy = CopyIo::new(&mut verifier_io, &mut duplex).map_err(ProverError::from);
|
||||
futures::try_join!(mux_fut.map_err(ProverError::from), copy)?;
|
||||
|
||||
// Wait for the verifier to finish closing.
|
||||
verifier_io.read_exact(&mut [0_u8; 5]).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A controller for the prover.
|
||||
#[derive(Clone)]
|
||||
pub struct ProverControl {
|
||||
mpc_ctrl: LeaderCtrl,
|
||||
}
|
||||
|
||||
impl ProverControl {
|
||||
/// Defers decryption of data from the server until the server has closed
|
||||
/// the connection.
|
||||
///
|
||||
/// This is a performance optimization which will significantly reduce the
|
||||
/// amount of upload bandwidth used by the prover.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// * The prover may need to close the connection to the server in order for
|
||||
/// it to close the connection on its end. If neither the prover or server
|
||||
/// close the connection this will cause a deadlock.
|
||||
pub async fn defer_decryption(&self) -> Result<(), ProverError> {
|
||||
self.mpc_ctrl
|
||||
.defer_decryption()
|
||||
.await
|
||||
.map_err(ProverError::from)
|
||||
}
|
||||
}
|
||||
|
||||
93
crates/tlsn/src/prover/client.rs
Normal file
93
crates/tlsn/src/prover/client.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
//! Provides a TLS client.
|
||||
|
||||
use crate::{mpz::ProverZk, prover::control::ControlError};
|
||||
use mpc_tls::SessionKeys;
|
||||
use std::{
|
||||
sync::mpsc::{Sender, SyncSender, sync_channel},
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tlsn_core::transcript::{TlsTranscript, Transcript};
|
||||
|
||||
mod mpc;
|
||||
|
||||
pub(crate) use mpc::MpcTlsClient;
|
||||
|
||||
/// TLS client for MPC and proxy-based TLS implementations.
|
||||
pub(crate) trait TlsClient {
|
||||
type Error: std::error::Error + Send + Sync + Unpin + 'static;
|
||||
|
||||
/// Returns `true` if the client wants to read TLS data from the server.
|
||||
fn wants_read_tls(&self) -> bool;
|
||||
|
||||
/// Returns `true` if the client wants to write TLS data to the server.
|
||||
fn wants_write_tls(&self) -> bool;
|
||||
|
||||
/// Reads TLS data from the server.
|
||||
fn read_tls(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Writes TLS data for the server into the provided buffer.
|
||||
fn write_tls(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Returns `true` if the client wants to read plaintext data.
|
||||
fn wants_read(&self) -> bool;
|
||||
|
||||
/// Returns `true` if the client wants to write plaintext data.
|
||||
fn wants_write(&self) -> bool;
|
||||
|
||||
/// Reads plaintext data from the server into the provided buffer.
|
||||
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Writes plaintext data to be sent to the server.
|
||||
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Client closes the connection.
|
||||
fn client_close(&mut self) -> Result<(), Self::Error>;
|
||||
|
||||
/// Server closes the connection.
|
||||
fn server_close(&mut self) -> Result<(), Self::Error>;
|
||||
|
||||
/// Returns a handle to control the client.
|
||||
fn handle(&self) -> ClientHandle;
|
||||
|
||||
/// Polls the client to make progress.
|
||||
fn poll(&mut self, cx: &mut Context) -> Poll<Result<TlsOutput, Self::Error>>;
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct ClientHandle {
|
||||
sender: Sender<Command>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) enum Command {
|
||||
IsDecrypting(SyncSender<bool>),
|
||||
SetDecrypt(bool),
|
||||
ClientClose,
|
||||
ServerClose,
|
||||
}
|
||||
|
||||
impl ClientHandle {
|
||||
pub(crate) fn enable_decryption(&self, enable: bool) -> Result<(), ControlError> {
|
||||
self.sender
|
||||
.send(Command::SetDecrypt(enable))
|
||||
.map_err(|_| ControlError)
|
||||
}
|
||||
|
||||
pub(crate) fn is_decrypting(&self) -> bool {
|
||||
let (sender, receiver) = sync_channel(1);
|
||||
let Ok(_) = self.sender.send(Command::IsDecrypting(sender)) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
receiver.recv().unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Output of a TLS session.
|
||||
pub(crate) struct TlsOutput {
|
||||
pub(crate) ctx: mpz_common::Context,
|
||||
pub(crate) vm: ProverZk,
|
||||
pub(crate) keys: SessionKeys,
|
||||
pub(crate) tls_transcript: TlsTranscript,
|
||||
pub(crate) transcript: Transcript,
|
||||
}
|
||||
503
crates/tlsn/src/prover/client/mpc.rs
Normal file
503
crates/tlsn/src/prover/client/mpc.rs
Normal file
@@ -0,0 +1,503 @@
|
||||
//! Implementation of an MPC-TLS client.
|
||||
|
||||
use crate::{
|
||||
mpz::{ProverMpc, ProverZk},
|
||||
prover::{
|
||||
ProverError,
|
||||
client::{ClientHandle, Command, TlsClient, TlsOutput},
|
||||
},
|
||||
tag::verify_tags,
|
||||
};
|
||||
use futures::{Future, FutureExt};
|
||||
use mpc_tls::{LeaderCtrl, SessionKeys};
|
||||
use mpz_common::Context;
|
||||
use mpz_vm_core::Execute;
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::{
|
||||
Arc,
|
||||
mpsc::{Receiver, Sender, channel},
|
||||
},
|
||||
task::Poll,
|
||||
};
|
||||
use tls_client::ClientConnection;
|
||||
use tlsn_core::transcript::TlsTranscript;
|
||||
use tlsn_deap::Deap;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{Span, debug, instrument, trace, warn};
|
||||
|
||||
pub(crate) type MpcFuture =
|
||||
Box<dyn Future<Output = Result<(Context, TlsTranscript), ProverError>> + Send>;
|
||||
|
||||
type FinalizeFuture =
|
||||
Box<dyn Future<Output = Result<(InnerState, Context, TlsTranscript), ProverError>> + Send>;
|
||||
|
||||
pub(crate) struct MpcTlsClient {
|
||||
sender: Sender<Command>,
|
||||
state: State,
|
||||
decrypt: bool,
|
||||
}
|
||||
|
||||
enum State {
|
||||
Start {
|
||||
mpc: Pin<MpcFuture>,
|
||||
inner: Box<InnerState>,
|
||||
receiver: Receiver<Command>,
|
||||
},
|
||||
Active {
|
||||
mpc: Pin<MpcFuture>,
|
||||
inner: Box<InnerState>,
|
||||
receiver: Receiver<Command>,
|
||||
},
|
||||
Busy {
|
||||
mpc: Pin<MpcFuture>,
|
||||
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
|
||||
receiver: Receiver<Command>,
|
||||
},
|
||||
MpcStop {
|
||||
mpc: Pin<MpcFuture>,
|
||||
inner: Box<InnerState>,
|
||||
},
|
||||
CloseBusy {
|
||||
mpc: Pin<MpcFuture>,
|
||||
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
|
||||
},
|
||||
Finishing {
|
||||
ctx: Context,
|
||||
transcript: Box<TlsTranscript>,
|
||||
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
|
||||
},
|
||||
Finalizing {
|
||||
fut: Pin<FinalizeFuture>,
|
||||
},
|
||||
Finished,
|
||||
Error,
|
||||
}
|
||||
|
||||
impl MpcTlsClient {
|
||||
pub(crate) fn new(
|
||||
mpc: MpcFuture,
|
||||
keys: SessionKeys,
|
||||
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
|
||||
span: Span,
|
||||
mpc_ctrl: LeaderCtrl,
|
||||
tls: ClientConnection,
|
||||
decrypt: bool,
|
||||
) -> Self {
|
||||
let inner = InnerState {
|
||||
span,
|
||||
tls,
|
||||
vm,
|
||||
keys,
|
||||
mpc_ctrl,
|
||||
client_closed: false,
|
||||
mpc_stopped: false,
|
||||
};
|
||||
let (sender, receiver) = channel();
|
||||
|
||||
Self {
|
||||
sender,
|
||||
decrypt,
|
||||
state: State::Start {
|
||||
receiver,
|
||||
mpc: Box::into_pin(mpc),
|
||||
inner: Box::new(inner),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn inner_client_mut(&mut self) -> Option<&mut ClientConnection> {
|
||||
if let State::Active { inner, .. } | State::MpcStop { inner, .. } = &mut self.state {
|
||||
Some(&mut inner.tls)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn inner_client(&self) -> Option<&ClientConnection> {
|
||||
if let State::Active { inner, .. } | State::MpcStop { inner, .. } = &self.state {
|
||||
Some(&inner.tls)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TlsClient for MpcTlsClient {
|
||||
type Error = ProverError;
|
||||
|
||||
fn wants_read_tls(&self) -> bool {
|
||||
if let Some(client) = self.inner_client() {
|
||||
client.wants_read()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn wants_write_tls(&self) -> bool {
|
||||
if let Some(client) = self.inner_client() {
|
||||
client.wants_write()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn read_tls(&mut self, mut buf: &[u8]) -> Result<usize, Self::Error> {
|
||||
if let Some(client) = self.inner_client_mut()
|
||||
&& client.wants_read()
|
||||
{
|
||||
client.read_tls(&mut buf).map_err(ProverError::from)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn write_tls(&mut self, mut buf: &mut [u8]) -> Result<usize, Self::Error> {
|
||||
if let Some(client) = self.inner_client_mut()
|
||||
&& client.wants_write()
|
||||
{
|
||||
client.write_tls(&mut buf).map_err(ProverError::from)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn wants_read(&self) -> bool {
|
||||
if let Some(client) = self.inner_client() {
|
||||
!client.plaintext_is_empty()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn wants_write(&self) -> bool {
|
||||
if let Some(client) = self.inner_client() {
|
||||
!client.sendable_plaintext_is_full()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
|
||||
if let Some(client) = self.inner_client_mut()
|
||||
&& !client.plaintext_is_empty()
|
||||
{
|
||||
client.read_plaintext(buf).map_err(ProverError::from)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
|
||||
if let Some(client) = self.inner_client_mut()
|
||||
&& !client.sendable_plaintext_is_full()
|
||||
{
|
||||
client.write_plaintext(buf).map_err(ProverError::from)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn client_close(&mut self) -> Result<(), Self::Error> {
|
||||
self.sender
|
||||
.send(Command::ClientClose)
|
||||
.map_err(|_| ProverError::state("unable to close connection clientside"))
|
||||
}
|
||||
|
||||
fn server_close(&mut self) -> Result<(), Self::Error> {
|
||||
self.sender
|
||||
.send(Command::ServerClose)
|
||||
.map_err(|_| ProverError::state("unable to close connection serverside"))
|
||||
}
|
||||
|
||||
fn handle(&self) -> ClientHandle {
|
||||
ClientHandle {
|
||||
sender: self.sender.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll(&mut self, cx: &mut std::task::Context) -> Poll<Result<TlsOutput, Self::Error>> {
|
||||
match std::mem::replace(&mut self.state, State::Error) {
|
||||
State::Start {
|
||||
mpc,
|
||||
inner,
|
||||
receiver,
|
||||
} => {
|
||||
trace!("inner client is starting");
|
||||
self.state = State::Busy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.start()),
|
||||
receiver,
|
||||
};
|
||||
self.poll(cx)
|
||||
}
|
||||
State::Active {
|
||||
mpc,
|
||||
inner,
|
||||
receiver,
|
||||
} => {
|
||||
trace!("inner client is active");
|
||||
|
||||
if !inner.tls.is_handshaking()
|
||||
&& let Ok(cmd) = receiver.try_recv()
|
||||
{
|
||||
match cmd {
|
||||
Command::ClientClose => {
|
||||
self.state = State::Busy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.client_close()),
|
||||
receiver,
|
||||
};
|
||||
}
|
||||
Command::ServerClose => {
|
||||
std::mem::drop(receiver);
|
||||
self.state = State::CloseBusy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.server_close()),
|
||||
};
|
||||
}
|
||||
Command::SetDecrypt(enable) => {
|
||||
self.decrypt = enable;
|
||||
self.state = State::Busy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.set_decrypt(enable)),
|
||||
receiver,
|
||||
};
|
||||
}
|
||||
Command::IsDecrypting(sender) => {
|
||||
_ = sender.send(self.decrypt);
|
||||
self.state = State::Busy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.run()),
|
||||
receiver,
|
||||
};
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.state = State::Busy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.run()),
|
||||
receiver,
|
||||
};
|
||||
}
|
||||
self.poll(cx)
|
||||
}
|
||||
State::Busy {
|
||||
mut mpc,
|
||||
mut fut,
|
||||
receiver,
|
||||
} => {
|
||||
trace!("inner client is busy");
|
||||
|
||||
let mpc_poll = mpc.as_mut().poll(cx)?;
|
||||
|
||||
assert!(
|
||||
matches!(mpc_poll, Poll::Pending),
|
||||
"mpc future should not be finished here"
|
||||
);
|
||||
|
||||
match fut.as_mut().poll(cx)? {
|
||||
Poll::Ready(inner) => {
|
||||
self.state = State::Active {
|
||||
mpc,
|
||||
inner,
|
||||
receiver,
|
||||
};
|
||||
}
|
||||
Poll::Pending => self.state = State::Busy { mpc, fut, receiver },
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
State::MpcStop { mpc, inner } => {
|
||||
trace!("inner client is stopping mpc");
|
||||
self.state = State::CloseBusy {
|
||||
mpc,
|
||||
fut: Box::pin(inner.stop()),
|
||||
};
|
||||
self.poll(cx)
|
||||
}
|
||||
State::CloseBusy { mut mpc, mut fut } => {
|
||||
trace!("inner client is busy closing");
|
||||
match (fut.poll_unpin(cx)?, mpc.poll_unpin(cx)?) {
|
||||
(Poll::Ready(inner), Poll::Ready((ctx, transcript))) => {
|
||||
self.state = State::Finalizing {
|
||||
fut: Box::pin(inner.finalize(ctx, transcript)),
|
||||
};
|
||||
self.poll(cx)
|
||||
}
|
||||
(Poll::Ready(inner), Poll::Pending) => {
|
||||
self.state = State::MpcStop { mpc, inner };
|
||||
Poll::Pending
|
||||
}
|
||||
(Poll::Pending, Poll::Ready((ctx, transcript))) => {
|
||||
self.state = State::Finishing {
|
||||
ctx,
|
||||
transcript: Box::new(transcript),
|
||||
fut,
|
||||
};
|
||||
Poll::Pending
|
||||
}
|
||||
(Poll::Pending, Poll::Pending) => {
|
||||
self.state = State::CloseBusy { mpc, fut };
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
State::Finishing {
|
||||
ctx,
|
||||
transcript,
|
||||
mut fut,
|
||||
} => {
|
||||
trace!("inner client is finishing");
|
||||
if let Poll::Ready(inner) = fut.poll_unpin(cx)? {
|
||||
self.state = State::Finalizing {
|
||||
fut: Box::pin(inner.finalize(ctx, *transcript)),
|
||||
};
|
||||
self.poll(cx)
|
||||
} else {
|
||||
self.state = State::Finishing {
|
||||
ctx,
|
||||
transcript,
|
||||
fut,
|
||||
};
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
State::Finalizing { mut fut } => match fut.poll_unpin(cx) {
|
||||
Poll::Ready(output) => {
|
||||
let (inner, ctx, tls_transcript) = output?;
|
||||
let InnerState { vm, keys, .. } = inner;
|
||||
|
||||
let transcript = tls_transcript
|
||||
.to_transcript()
|
||||
.expect("transcript is complete");
|
||||
|
||||
let (_, vm) = Arc::into_inner(vm)
|
||||
.expect("vm should have only 1 reference")
|
||||
.into_inner()
|
||||
.into_inner();
|
||||
|
||||
let output = TlsOutput {
|
||||
ctx,
|
||||
vm,
|
||||
keys,
|
||||
tls_transcript,
|
||||
transcript,
|
||||
};
|
||||
|
||||
self.state = State::Finished;
|
||||
Poll::Ready(Ok(output))
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.state = State::Finalizing { fut };
|
||||
Poll::Pending
|
||||
}
|
||||
},
|
||||
State::Finished => Poll::Ready(Err(ProverError::state(
|
||||
"mpc tls client polled again in finished state",
|
||||
))),
|
||||
State::Error => {
|
||||
Poll::Ready(Err(ProverError::state("mpc tls client is in error state")))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct InnerState {
|
||||
span: Span,
|
||||
tls: ClientConnection,
|
||||
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
|
||||
keys: SessionKeys,
|
||||
mpc_ctrl: LeaderCtrl,
|
||||
client_closed: bool,
|
||||
mpc_stopped: bool,
|
||||
}
|
||||
|
||||
impl InnerState {
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn start(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
self.tls.start().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "trace", skip_all, err)]
|
||||
async fn run(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
self.tls.process_new_packets().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn set_decrypt(self: Box<Self>, enable: bool) -> Result<Box<Self>, ProverError> {
|
||||
self.mpc_ctrl.enable_decryption(enable).await?;
|
||||
self.run().await
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn client_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
if !self.client_closed {
|
||||
debug!("sending close notify");
|
||||
if let Err(e) = self.tls.send_close_notify().await {
|
||||
warn!("failed to send close_notify to server: {}", e);
|
||||
}
|
||||
self.client_closed = true;
|
||||
}
|
||||
self.run().await
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn server_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
self.tls.process_new_packets().await?;
|
||||
self.tls.server_closed().await?;
|
||||
debug!("closed connection serverside");
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn stop(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
|
||||
self.tls.process_new_packets().await?;
|
||||
if !self.mpc_stopped && self.tls.plaintext_is_empty() && self.tls.is_empty().await? {
|
||||
self.mpc_ctrl.stop().await?;
|
||||
self.mpc_stopped = true;
|
||||
debug!("stopped mpc");
|
||||
}
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
|
||||
async fn finalize(
|
||||
self,
|
||||
mut ctx: Context,
|
||||
transcript: TlsTranscript,
|
||||
) -> Result<(Self, Context, TlsTranscript), ProverError> {
|
||||
{
|
||||
let mut vm = self.vm.try_lock().expect("VM should not be locked");
|
||||
|
||||
// Finalize DEAP.
|
||||
vm.finalize(&mut ctx).await.map_err(ProverError::mpc)?;
|
||||
|
||||
debug!("mpc finalized");
|
||||
|
||||
// Pull out ZK VM.
|
||||
let mut zk = vm.zk();
|
||||
|
||||
// Prove tag verification of received records.
|
||||
// The prover drops the proof output.
|
||||
let _ = verify_tags(
|
||||
&mut *zk,
|
||||
(self.keys.server_write_key, self.keys.server_write_iv),
|
||||
self.keys.server_write_mac_key,
|
||||
*transcript.version(),
|
||||
transcript.recv().to_vec(),
|
||||
)
|
||||
.map_err(ProverError::zk)?;
|
||||
debug!("verified tags from server");
|
||||
|
||||
zk.execute_all(&mut ctx).await.map_err(ProverError::zk)?
|
||||
}
|
||||
|
||||
debug!("MPC-TLS done");
|
||||
Ok((self, ctx, transcript))
|
||||
}
|
||||
}
|
||||
66
crates/tlsn/src/prover/conn.rs
Normal file
66
crates/tlsn/src/prover/conn.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
use futures::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use futures_plex::DuplexStream;
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
/// A TLS connection to a server.
|
||||
///
|
||||
/// This type implements [`AsyncRead`] and [`AsyncWrite`] and can be used to
|
||||
/// communicate with a server using TLS.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// This connection is closed on a best-effort basis if this is dropped. To
|
||||
/// ensure a clean close, you should call
|
||||
/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the
|
||||
/// connection.
|
||||
pub struct TlsConnection {
|
||||
duplex: DuplexStream,
|
||||
}
|
||||
|
||||
impl TlsConnection {
|
||||
pub(crate) fn new(duplex: DuplexStream) -> Self {
|
||||
Self { duplex }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TlsConnection {
|
||||
fn drop(&mut self) {
|
||||
if let Err(err) = futures::executor::block_on(self.duplex.close()) {
|
||||
tracing::error!("error closing connection: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for TlsConnection {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
let duplex = Pin::new(&mut self.duplex);
|
||||
duplex.poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TlsConnection {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
let duplex = Pin::new(&mut self.duplex);
|
||||
duplex.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
let duplex = Pin::new(&mut self.duplex);
|
||||
duplex.poll_close(cx)
|
||||
}
|
||||
}
|
||||
29
crates/tlsn/src/prover/control.rs
Normal file
29
crates/tlsn/src/prover/control.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use crate::prover::client::ClientHandle;
|
||||
|
||||
/// A controller for the prover.
|
||||
///
|
||||
/// Can be used to control the decryption of server traffic.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ProverControl {
|
||||
pub(crate) handle: ClientHandle,
|
||||
}
|
||||
|
||||
impl ProverControl {
|
||||
/// Returns whether the prover is decrypting the server traffic.
|
||||
pub fn is_decrypting(&self) -> bool {
|
||||
self.handle.is_decrypting()
|
||||
}
|
||||
|
||||
/// Enables or disables the decryption of server traffic.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `enable` - If decryption should be enabled or disabled.
|
||||
pub fn enable_decryption(&self, enable: bool) -> Result<(), ControlError> {
|
||||
self.handle.enable_decryption(enable)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Unable to send control command to prover.")]
|
||||
pub struct ControlError;
|
||||
@@ -2,6 +2,8 @@ use std::{error::Error, fmt};
|
||||
|
||||
use mpc_tls::MpcTlsError;
|
||||
|
||||
use crate::transcript_internal::commit::encoding::EncodingError;
|
||||
|
||||
/// Error for [`Prover`](crate::prover::Prover).
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub struct ProverError {
|
||||
@@ -47,6 +49,13 @@ impl ProverError {
|
||||
{
|
||||
Self::new(ErrorKind::Commit, source)
|
||||
}
|
||||
|
||||
pub(crate) fn state<E>(source: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self::new(ErrorKind::State, source)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -56,6 +65,7 @@ enum ErrorKind {
|
||||
Zk,
|
||||
Config,
|
||||
Commit,
|
||||
State,
|
||||
}
|
||||
|
||||
impl fmt::Display for ProverError {
|
||||
@@ -68,6 +78,7 @@ impl fmt::Display for ProverError {
|
||||
ErrorKind::Zk => f.write_str("zk error")?,
|
||||
ErrorKind::Config => f.write_str("config error")?,
|
||||
ErrorKind::Commit => f.write_str("commit error")?,
|
||||
ErrorKind::State => f.write_str("state error")?,
|
||||
}
|
||||
|
||||
if let Some(source) = &self.source {
|
||||
@@ -84,8 +95,8 @@ impl From<std::io::Error> for ProverError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tls_client_async::ConnectionError> for ProverError {
|
||||
fn from(e: tls_client_async::ConnectionError) -> Self {
|
||||
impl From<tls_client::Error> for ProverError {
|
||||
fn from(e: tls_client::Error) -> Self {
|
||||
Self::new(ErrorKind::Io, e)
|
||||
}
|
||||
}
|
||||
@@ -107,3 +118,9 @@ impl From<MpcTlsError> for ProverError {
|
||||
Self::new(ErrorKind::Mpc, e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EncodingError> for ProverError {
|
||||
fn from(e: EncodingError) -> Self {
|
||||
Self::new(ErrorKind::Commit, e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
//! This module collects futures which are used by the [Prover].
|
||||
|
||||
use super::{Prover, ProverControl, ProverError, state};
|
||||
use futures::Future;
|
||||
use std::pin::Pin;
|
||||
|
||||
/// Prover future which must be polled for the TLS connection to make progress.
|
||||
pub struct ProverFuture {
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub(crate) fut: Pin<
|
||||
Box<dyn Future<Output = Result<Prover<state::Committed>, ProverError>> + Send + 'static>,
|
||||
>,
|
||||
pub(crate) ctrl: ProverControl,
|
||||
}
|
||||
|
||||
impl ProverFuture {
|
||||
/// Returns a controller for the prover for advanced functionality.
|
||||
pub fn control(&self) -> ProverControl {
|
||||
self.ctrl.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for ProverFuture {
|
||||
type Output = Result<Prover<state::Committed>, ProverError>;
|
||||
|
||||
fn poll(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Self::Output> {
|
||||
self.fut.as_mut().poll(cx)
|
||||
}
|
||||
}
|
||||
@@ -13,10 +13,17 @@ use tlsn_core::{
|
||||
|
||||
use crate::{
|
||||
prover::ProverError,
|
||||
transcript_internal::{TranscriptRefs, auth::prove_plaintext, commit::hash::prove_hash},
|
||||
transcript_internal::{
|
||||
TranscriptRefs,
|
||||
auth::prove_plaintext,
|
||||
commit::{
|
||||
encoding::{self, MacStore},
|
||||
hash::prove_hash,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
pub(crate) async fn prove<T: Vm<Binary> + Send + Sync>(
|
||||
pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
|
||||
ctx: &mut Context,
|
||||
vm: &mut T,
|
||||
keys: &SessionKeys,
|
||||
@@ -38,6 +45,13 @@ pub(crate) async fn prove<T: Vm<Binary> + Send + Sync>(
|
||||
Direction::Sent => commit_sent.union_mut(idx),
|
||||
Direction::Received => commit_recv.union_mut(idx),
|
||||
});
|
||||
|
||||
commit_config
|
||||
.iter_encoding()
|
||||
.for_each(|(direction, idx)| match direction {
|
||||
Direction::Sent => commit_sent.union_mut(idx),
|
||||
Direction::Received => commit_recv.union_mut(idx),
|
||||
});
|
||||
}
|
||||
|
||||
let transcript_refs = TranscriptRefs {
|
||||
@@ -88,6 +102,45 @@ pub(crate) async fn prove<T: Vm<Binary> + Send + Sync>(
|
||||
|
||||
vm.execute_all(ctx).await.map_err(ProverError::zk)?;
|
||||
|
||||
if let Some(commit_config) = config.transcript_commit()
|
||||
&& commit_config.has_encoding()
|
||||
{
|
||||
let mut sent_ranges = RangeSet::default();
|
||||
let mut recv_ranges = RangeSet::default();
|
||||
for (dir, idx) in commit_config.iter_encoding() {
|
||||
match dir {
|
||||
Direction::Sent => sent_ranges.union_mut(idx),
|
||||
Direction::Received => recv_ranges.union_mut(idx),
|
||||
}
|
||||
}
|
||||
|
||||
let sent_map = transcript_refs
|
||||
.sent
|
||||
.index(&sent_ranges)
|
||||
.expect("indices are valid");
|
||||
let recv_map = transcript_refs
|
||||
.recv
|
||||
.index(&recv_ranges)
|
||||
.expect("indices are valid");
|
||||
|
||||
let (commitment, tree) = encoding::receive(
|
||||
ctx,
|
||||
vm,
|
||||
*commit_config.encoding_hash_alg(),
|
||||
&sent_map,
|
||||
&recv_map,
|
||||
commit_config.iter_encoding(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
output
|
||||
.transcript_commitments
|
||||
.push(TranscriptCommitment::Encoding(commitment));
|
||||
output
|
||||
.transcript_secrets
|
||||
.push(TranscriptSecret::Encoding(tree));
|
||||
}
|
||||
|
||||
if let Some((hash_fut, hash_secrets)) = hash_commitments {
|
||||
let hash_commitments = hash_fut.try_recv().map_err(ProverError::commit)?;
|
||||
for (commitment, secret) in hash_commitments.into_iter().zip(hash_secrets) {
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures_plex::DuplexStream;
|
||||
use mpc_tls::{MpcTlsLeader, SessionKeys};
|
||||
use mpz_common::Context;
|
||||
use tlsn_core::{
|
||||
@@ -14,6 +15,10 @@ use tokio::sync::Mutex;
|
||||
use crate::{
|
||||
mpz::{ProverMpc, ProverZk},
|
||||
mux::{MuxControl, MuxFuture},
|
||||
prover::{
|
||||
ProverError,
|
||||
client::{TlsClient, TlsOutput},
|
||||
},
|
||||
};
|
||||
|
||||
/// Entry state
|
||||
@@ -24,6 +29,7 @@ opaque_debug::implement!(Initialized);
|
||||
/// State after the verifier has accepted the proposed TLS commitment protocol
|
||||
/// configuration and preprocessing has completed.
|
||||
pub struct CommitAccepted {
|
||||
pub(crate) verifier_io: Option<DuplexStream>,
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) mpc_tls: MpcTlsLeader,
|
||||
@@ -33,8 +39,49 @@ pub struct CommitAccepted {
|
||||
|
||||
opaque_debug::implement!(CommitAccepted);
|
||||
|
||||
/// State when the TLS client has been setup.
|
||||
pub struct Setup {
|
||||
pub(crate) verifier_io: Option<DuplexStream>,
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) server_name: ServerName,
|
||||
pub(crate) tls_client: Box<dyn TlsClient<Error = ProverError> + Send>,
|
||||
pub(crate) client_io: DuplexStream,
|
||||
}
|
||||
|
||||
opaque_debug::implement!(Setup);
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
/// State during the MPC-TLS connection.
|
||||
#[project = ConnectedProj]
|
||||
pub struct Connected<S, T> {
|
||||
#[pin]
|
||||
pub(crate) verifier_io: Option<DuplexStream>,
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) server_name: ServerName,
|
||||
pub(crate) tls_client: Box<dyn TlsClient<Error = ProverError> + Send>,
|
||||
#[pin]
|
||||
pub(crate) client_io: DuplexStream,
|
||||
pub(crate) output: Option<TlsOutput>,
|
||||
#[pin]
|
||||
pub(crate) server_socket: S,
|
||||
#[pin]
|
||||
pub(crate) verifier_socket: T,
|
||||
#[pin]
|
||||
pub(crate) tls_client_to_server_buf: DuplexStream,
|
||||
#[pin]
|
||||
pub(crate) server_to_tls_client_buf: DuplexStream,
|
||||
pub(crate) client_closed: bool,
|
||||
pub(crate) server_closed: bool
|
||||
}
|
||||
}
|
||||
|
||||
opaque_debug::implement!(Connected<S, T>);
|
||||
|
||||
/// State after the TLS transcript has been committed.
|
||||
pub struct Committed {
|
||||
pub(crate) verifier_io: Option<DuplexStream>,
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) ctx: Context,
|
||||
@@ -52,11 +99,15 @@ pub trait ProverState: sealed::Sealed {}
|
||||
|
||||
impl ProverState for Initialized {}
|
||||
impl ProverState for CommitAccepted {}
|
||||
impl ProverState for Setup {}
|
||||
impl<S, T> ProverState for Connected<S, T> {}
|
||||
impl ProverState for Committed {}
|
||||
|
||||
mod sealed {
|
||||
pub trait Sealed {}
|
||||
impl Sealed for super::Initialized {}
|
||||
impl Sealed for super::CommitAccepted {}
|
||||
impl Sealed for super::Setup {}
|
||||
impl<S, T> Sealed for super::Connected<S, T> {}
|
||||
impl Sealed for super::Committed {}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
//! Plaintext commitment and proof of encryption.
|
||||
|
||||
pub(crate) mod encoding;
|
||||
pub(crate) mod hash;
|
||||
|
||||
267
crates/tlsn/src/transcript_internal/commit/encoding.rs
Normal file
267
crates/tlsn/src/transcript_internal/commit/encoding.rs
Normal file
@@ -0,0 +1,267 @@
|
||||
//! Encoding commitment protocol.
|
||||
|
||||
use std::ops::Range;
|
||||
|
||||
use mpz_common::Context;
|
||||
use mpz_memory_core::{
|
||||
Vector,
|
||||
binary::U8,
|
||||
correlated::{Delta, Key, Mac},
|
||||
};
|
||||
use rand::Rng;
|
||||
use rangeset::set::RangeSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serio::{SinkExt, stream::IoStreamExt};
|
||||
use tlsn_core::{
|
||||
hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256},
|
||||
transcript::{
|
||||
Direction,
|
||||
encoding::{
|
||||
Encoder, EncoderSecret, EncodingCommitment, EncodingProvider, EncodingProviderError,
|
||||
EncodingTree, EncodingTreeError, new_encoder,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
map::{Item, RangeMap},
|
||||
transcript_internal::ReferenceMap,
|
||||
};
|
||||
|
||||
/// Bytes of encoding, per byte.
|
||||
const ENCODING_SIZE: usize = 128;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Encodings {
|
||||
sent: Vec<u8>,
|
||||
recv: Vec<u8>,
|
||||
}
|
||||
|
||||
/// Transfers encodings for the provided plaintext ranges.
|
||||
pub(crate) async fn transfer<K: KeyStore>(
|
||||
ctx: &mut Context,
|
||||
store: &K,
|
||||
sent: &ReferenceMap,
|
||||
recv: &ReferenceMap,
|
||||
) -> Result<(EncoderSecret, EncodingCommitment), EncodingError> {
|
||||
let secret = EncoderSecret::new(rand::rng().random(), store.delta().as_block().to_bytes());
|
||||
let encoder = new_encoder(&secret);
|
||||
|
||||
// Collects the encodings for the provided plaintext ranges.
|
||||
fn collect_encodings(
|
||||
encoder: &impl Encoder,
|
||||
store: &impl KeyStore,
|
||||
direction: Direction,
|
||||
map: &ReferenceMap,
|
||||
) -> Vec<u8> {
|
||||
let mut encodings = Vec::with_capacity(map.len() * ENCODING_SIZE);
|
||||
for (range, chunk) in map.iter() {
|
||||
let start = encodings.len();
|
||||
encoder.encode_range(direction, range, &mut encodings);
|
||||
let keys = store
|
||||
.get_keys(*chunk)
|
||||
.expect("keys are present for provided plaintext ranges");
|
||||
encodings[start..]
|
||||
.iter_mut()
|
||||
.zip(keys.iter().flat_map(|key| key.as_block().as_bytes()))
|
||||
.for_each(|(encoding, key)| {
|
||||
*encoding ^= *key;
|
||||
});
|
||||
}
|
||||
encodings
|
||||
}
|
||||
|
||||
let encodings = Encodings {
|
||||
sent: collect_encodings(&encoder, store, Direction::Sent, sent),
|
||||
recv: collect_encodings(&encoder, store, Direction::Received, recv),
|
||||
};
|
||||
|
||||
let frame_limit = ctx
|
||||
.io()
|
||||
.limit()
|
||||
.saturating_add(encodings.sent.len() + encodings.recv.len());
|
||||
ctx.io_mut().with_limit(frame_limit).send(encodings).await?;
|
||||
|
||||
let root = ctx.io_mut().expect_next().await?;
|
||||
|
||||
Ok((secret, EncodingCommitment { root }))
|
||||
}
|
||||
|
||||
/// Receives and commits to the encodings for the provided plaintext ranges.
|
||||
pub(crate) async fn receive<M: MacStore>(
|
||||
ctx: &mut Context,
|
||||
store: &M,
|
||||
hash_alg: HashAlgId,
|
||||
sent: &ReferenceMap,
|
||||
recv: &ReferenceMap,
|
||||
idxs: impl IntoIterator<Item = &(Direction, RangeSet<usize>)>,
|
||||
) -> Result<(EncodingCommitment, EncodingTree), EncodingError> {
|
||||
let hasher: &(dyn HashAlgorithm + Send + Sync) = match hash_alg {
|
||||
HashAlgId::SHA256 => &Sha256::default(),
|
||||
HashAlgId::KECCAK256 => &Keccak256::default(),
|
||||
HashAlgId::BLAKE3 => &Blake3::default(),
|
||||
alg => {
|
||||
return Err(ErrorRepr::UnsupportedHashAlgorithm(alg).into());
|
||||
}
|
||||
};
|
||||
|
||||
let (sent_len, recv_len) = (sent.len(), recv.len());
|
||||
let frame_limit = ctx
|
||||
.io()
|
||||
.limit()
|
||||
.saturating_add(ENCODING_SIZE * (sent_len + recv_len));
|
||||
let encodings: Encodings = ctx.io_mut().with_limit(frame_limit).expect_next().await?;
|
||||
|
||||
if encodings.sent.len() != sent_len * ENCODING_SIZE {
|
||||
return Err(ErrorRepr::IncorrectMacCount {
|
||||
direction: Direction::Sent,
|
||||
expected: sent_len,
|
||||
got: encodings.sent.len() / ENCODING_SIZE,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
if encodings.recv.len() != recv_len * ENCODING_SIZE {
|
||||
return Err(ErrorRepr::IncorrectMacCount {
|
||||
direction: Direction::Received,
|
||||
expected: recv_len,
|
||||
got: encodings.recv.len() / ENCODING_SIZE,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
// Collects a map of plaintext ranges to their encodings.
|
||||
fn collect_map(
|
||||
store: &impl MacStore,
|
||||
mut encodings: Vec<u8>,
|
||||
map: &ReferenceMap,
|
||||
) -> RangeMap<EncodingSlice> {
|
||||
let mut encoding_map = Vec::new();
|
||||
let mut pos = 0;
|
||||
for (range, chunk) in map.iter() {
|
||||
let macs = store
|
||||
.get_macs(*chunk)
|
||||
.expect("MACs are present for provided plaintext ranges");
|
||||
let encoding = &mut encodings[pos..pos + range.len() * ENCODING_SIZE];
|
||||
encoding
|
||||
.iter_mut()
|
||||
.zip(macs.iter().flat_map(|mac| mac.as_bytes()))
|
||||
.for_each(|(encoding, mac)| {
|
||||
*encoding ^= *mac;
|
||||
});
|
||||
|
||||
encoding_map.push((range.start, EncodingSlice::from(&(*encoding))));
|
||||
pos += range.len() * ENCODING_SIZE;
|
||||
}
|
||||
RangeMap::new(encoding_map)
|
||||
}
|
||||
|
||||
let provider = Provider {
|
||||
sent: collect_map(store, encodings.sent, sent),
|
||||
recv: collect_map(store, encodings.recv, recv),
|
||||
};
|
||||
|
||||
let tree = EncodingTree::new(hasher, idxs, &provider)?;
|
||||
let root = tree.root();
|
||||
|
||||
ctx.io_mut().send(root.clone()).await?;
|
||||
|
||||
let commitment = EncodingCommitment { root };
|
||||
|
||||
Ok((commitment, tree))
|
||||
}
|
||||
|
||||
pub(crate) trait KeyStore {
|
||||
fn delta(&self) -> Δ
|
||||
|
||||
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]>;
|
||||
}
|
||||
|
||||
pub(crate) trait MacStore {
|
||||
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Provider {
|
||||
sent: RangeMap<EncodingSlice>,
|
||||
recv: RangeMap<EncodingSlice>,
|
||||
}
|
||||
|
||||
impl EncodingProvider for Provider {
|
||||
fn provide_encoding(
|
||||
&self,
|
||||
direction: Direction,
|
||||
range: Range<usize>,
|
||||
dest: &mut Vec<u8>,
|
||||
) -> Result<(), EncodingProviderError> {
|
||||
let encodings = match direction {
|
||||
Direction::Sent => &self.sent,
|
||||
Direction::Received => &self.recv,
|
||||
};
|
||||
|
||||
let encoding = encodings.get(range).ok_or(EncodingProviderError)?;
|
||||
|
||||
dest.extend_from_slice(encoding);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodingSlice(Vec<u8>);
|
||||
|
||||
impl From<&[u8]> for EncodingSlice {
|
||||
fn from(value: &[u8]) -> Self {
|
||||
Self(value.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl Item for EncodingSlice {
|
||||
type Slice<'a>
|
||||
= &'a [u8]
|
||||
where
|
||||
Self: 'a;
|
||||
|
||||
fn length(&self) -> usize {
|
||||
self.0.len() / ENCODING_SIZE
|
||||
}
|
||||
|
||||
fn slice<'a>(&'a self, range: Range<usize>) -> Option<Self::Slice<'a>> {
|
||||
self.0
|
||||
.get(range.start * ENCODING_SIZE..range.end * ENCODING_SIZE)
|
||||
}
|
||||
}
|
||||
|
||||
/// Encoding protocol error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct EncodingError(#[from] ErrorRepr);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("encoding protocol error: {0}")]
|
||||
enum ErrorRepr {
|
||||
#[error("I/O error: {0}")]
|
||||
Io(std::io::Error),
|
||||
#[error("incorrect MAC count for {direction}: expected {expected}, got {got}")]
|
||||
IncorrectMacCount {
|
||||
direction: Direction,
|
||||
expected: usize,
|
||||
got: usize,
|
||||
},
|
||||
#[error("encoding tree error: {0}")]
|
||||
EncodingTree(EncodingTreeError),
|
||||
#[error("unsupported hash algorithm: {0}")]
|
||||
UnsupportedHashAlgorithm(HashAlgId),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for EncodingError {
|
||||
fn from(value: std::io::Error) -> Self {
|
||||
Self(ErrorRepr::Io(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EncodingTreeError> for EncodingError {
|
||||
fn from(value: EncodingTreeError) -> Self {
|
||||
Self(ErrorRepr::EncodingTree(value))
|
||||
}
|
||||
}
|
||||
143
crates/tlsn/src/utils.rs
Normal file
143
crates/tlsn/src/utils.rs
Normal file
@@ -0,0 +1,143 @@
|
||||
//! Execution context.
|
||||
|
||||
use std::{
|
||||
io::ErrorKind,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures::{AsyncRead, AsyncWrite, future::FusedFuture};
|
||||
use futures_plex::DuplexStream;
|
||||
use mpz_common::context::Multithread;
|
||||
|
||||
use crate::mux::MuxControl;
|
||||
|
||||
/// Maximum concurrency for multi-threaded context.
|
||||
pub(crate) const MAX_CONCURRENCY: usize = 8;
|
||||
|
||||
/// Builds a multi-threaded context with the given muxer.
|
||||
pub(crate) fn build_mt_context(mux: MuxControl) -> Multithread {
|
||||
let builder = Multithread::builder().mux(mux).concurrency(MAX_CONCURRENCY);
|
||||
|
||||
#[cfg(all(feature = "web", target_arch = "wasm32"))]
|
||||
let builder = builder.spawn_handler(|f| {
|
||||
let _ = web_spawn::spawn(f);
|
||||
Ok(())
|
||||
});
|
||||
|
||||
builder.build().unwrap()
|
||||
}
|
||||
|
||||
/// Polls the future while copying bytes between two duplex streams.
|
||||
///
|
||||
/// Returns as soon as the future is ready, without closing IO.
|
||||
pub(crate) async fn await_with_copy_io<'a, S, T>(
|
||||
mut fut: Pin<Box<dyn FusedFuture<Output = T> + Send + 'a>>,
|
||||
io: S,
|
||||
duplex: &mut DuplexStream,
|
||||
) -> T
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
{
|
||||
let mut copy = CopyIo::new(io, duplex);
|
||||
|
||||
loop {
|
||||
futures::select! {
|
||||
_ = copy => (),
|
||||
output = fut => break output
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct CopyIo<'a, S> {
|
||||
#[pin]
|
||||
io: S,
|
||||
#[pin]
|
||||
duplex: &'a mut DuplexStream,
|
||||
io_done: bool,
|
||||
duplex_done: bool,
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, S> CopyIo<'a, S> {
|
||||
pub(crate) fn new(io: S, duplex: &'a mut DuplexStream) -> Self {
|
||||
Self {
|
||||
io,
|
||||
duplex,
|
||||
io_done: false,
|
||||
duplex_done: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, S> Future for CopyIo<'a, S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
{
|
||||
type Output = std::io::Result<()>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut this = self.project();
|
||||
|
||||
loop {
|
||||
let mut is_pending = true;
|
||||
|
||||
if !*this.duplex_done {
|
||||
match this.duplex.poll_read_to(cx, this.io.as_mut()) {
|
||||
Poll::Ready(Ok(read)) if read > 0 => is_pending = false,
|
||||
Poll::Ready(Ok(_)) => {
|
||||
is_pending = false;
|
||||
*this.duplex_done = true;
|
||||
}
|
||||
Poll::Ready(Err(err))
|
||||
if err.kind() == ErrorKind::BrokenPipe
|
||||
|| err.kind() == ErrorKind::ConnectionReset
|
||||
|| err.kind() == ErrorKind::NotConnected =>
|
||||
{
|
||||
is_pending = false;
|
||||
*this.duplex_done = true;
|
||||
}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => (),
|
||||
}
|
||||
}
|
||||
|
||||
if !*this.io_done {
|
||||
match this.duplex.poll_write_from(cx, this.io.as_mut()) {
|
||||
Poll::Ready(Ok(write)) if write > 0 => is_pending = false,
|
||||
Poll::Ready(Ok(_)) => {
|
||||
is_pending = false;
|
||||
*this.io_done = true;
|
||||
}
|
||||
Poll::Ready(Err(err))
|
||||
if err.kind() == ErrorKind::BrokenPipe
|
||||
|| err.kind() == ErrorKind::ConnectionReset
|
||||
|| err.kind() == ErrorKind::NotConnected =>
|
||||
{
|
||||
is_pending = false;
|
||||
*this.io_done = true
|
||||
}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => (),
|
||||
}
|
||||
}
|
||||
|
||||
if *this.io_done || *this.duplex_done {
|
||||
return Poll::Ready(Ok(()));
|
||||
} else if is_pending {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, S> FusedFuture for CopyIo<'a, S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
{
|
||||
fn is_terminated(&self) -> bool {
|
||||
self.duplex_done || self.io_done
|
||||
}
|
||||
}
|
||||
@@ -10,14 +10,14 @@ pub use error::VerifierError;
|
||||
pub use tlsn_core::{VerifierOutput, webpki::ServerCertVerifier};
|
||||
|
||||
use crate::{
|
||||
Role,
|
||||
context::build_mt_context,
|
||||
BUF_CAP, Role,
|
||||
mpz::{VerifierDeps, build_verifier_deps, translate_keys},
|
||||
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
|
||||
mux::attach_mux,
|
||||
tag::verify_tags,
|
||||
utils::{CopyIo, await_with_copy_io, build_mt_context},
|
||||
};
|
||||
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
|
||||
use futures::{AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt, TryFutureExt};
|
||||
use mpz_vm_core::prelude::*;
|
||||
use serio::{SinkExt, stream::IoStreamExt};
|
||||
use tlsn_core::{
|
||||
@@ -66,48 +66,73 @@ impl Verifier<state::Initialized> {
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `socket` - The socket to the prover.
|
||||
/// * `prover_io` - The IO to the prover.
|
||||
#[instrument(parent = &self.span, level = "info", skip_all, err)]
|
||||
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin>(
|
||||
self,
|
||||
socket: S,
|
||||
mut prover_io: S,
|
||||
) -> Result<Verifier<state::CommitStart>, VerifierError> {
|
||||
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Verifier);
|
||||
let mut mt = build_mt_context(mux_ctrl.clone());
|
||||
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
|
||||
let (duplex_a, mut duplex_b) = futures_plex::duplex(BUF_CAP);
|
||||
|
||||
// Receives protocol configuration from prover to perform compatibility check.
|
||||
let TlsCommitRequestMsg { request, version } =
|
||||
mux_fut.poll_with(ctx.io_mut().expect_next()).await?;
|
||||
let (mut mux_fut, mux_ctrl) = attach_mux(duplex_a, Role::Verifier);
|
||||
let mut mt = build_mt_context(mux_ctrl.clone());
|
||||
|
||||
let fut = Box::pin(
|
||||
async {
|
||||
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
|
||||
|
||||
// Receives protocol configuration from prover to perform compatibility check.
|
||||
let TlsCommitRequestMsg { request, version } =
|
||||
mux_fut.poll_with(ctx.io_mut().expect_next()).await?;
|
||||
|
||||
Ok::<_, VerifierError>((request, version, ctx))
|
||||
}
|
||||
.fuse(),
|
||||
);
|
||||
|
||||
let (request, version, mut ctx) =
|
||||
await_with_copy_io(fut, &mut prover_io, &mut duplex_b).await?;
|
||||
|
||||
if version != *crate::VERSION {
|
||||
let msg = format!(
|
||||
"prover version does not match with verifier: {version} != {}",
|
||||
*crate::VERSION
|
||||
);
|
||||
mux_fut
|
||||
.poll_with(ctx.io_mut().send(Response::err(Some(msg.clone()))))
|
||||
.await?;
|
||||
|
||||
// Wait for the prover to correctly close the connection.
|
||||
if !mux_fut.is_complete() {
|
||||
mux_ctrl.close();
|
||||
mux_fut.await?;
|
||||
}
|
||||
let fut = Box::pin(
|
||||
async {
|
||||
mux_fut
|
||||
.poll_with(ctx.io_mut().send(Response::err(Some(msg.clone()))))
|
||||
.await?;
|
||||
|
||||
return Err(VerifierError::config(msg));
|
||||
// Wait for the prover to correctly close the connection.
|
||||
if !mux_fut.is_complete() {
|
||||
mux_ctrl.close();
|
||||
mux_fut.await?;
|
||||
}
|
||||
|
||||
Err(VerifierError::config(msg))
|
||||
}
|
||||
.fuse(),
|
||||
);
|
||||
let copy = CopyIo::new(prover_io, &mut duplex_b).map_err(VerifierError::from);
|
||||
let (config_err, _) = futures::try_join!(fut, copy)?;
|
||||
|
||||
return Err(config_err);
|
||||
}
|
||||
|
||||
Ok(Verifier {
|
||||
let verifier = Verifier {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::CommitStart {
|
||||
prover_io: Some(duplex_b),
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
ctx,
|
||||
request,
|
||||
},
|
||||
})
|
||||
};
|
||||
Ok(verifier)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,13 +143,36 @@ impl Verifier<state::CommitStart> {
|
||||
}
|
||||
|
||||
/// Accepts the proposed protocol configuration.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `prover_io` - The IO to the prover.
|
||||
pub async fn accept<S: AsyncWrite + AsyncRead + Send + Unpin>(
|
||||
mut self,
|
||||
prover_io: S,
|
||||
) -> Result<Verifier<state::CommitAccepted>, VerifierError> {
|
||||
let mut duplex = self
|
||||
.state
|
||||
.prover_io
|
||||
.take()
|
||||
.expect("duplex should be available");
|
||||
|
||||
let fut = Box::pin(self.accept_inner().fuse());
|
||||
let mut verifier = await_with_copy_io(fut, prover_io, &mut duplex).await?;
|
||||
|
||||
verifier.state.prover_io = Some(duplex);
|
||||
Ok(verifier)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "info", skip_all, err)]
|
||||
pub async fn accept(self) -> Result<Verifier<state::CommitAccepted>, VerifierError> {
|
||||
async fn accept_inner(self) -> Result<Verifier<state::CommitAccepted>, VerifierError> {
|
||||
let state::CommitStart {
|
||||
prover_io,
|
||||
mux_ctrl,
|
||||
mut mux_fut,
|
||||
mut ctx,
|
||||
request,
|
||||
..
|
||||
} = self.state;
|
||||
|
||||
mux_fut.poll_with(ctx.io_mut().send(Response::ok())).await?;
|
||||
@@ -151,6 +199,7 @@ impl Verifier<state::CommitStart> {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::CommitAccepted {
|
||||
prover_io,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
mpc_tls,
|
||||
@@ -161,8 +210,31 @@ impl Verifier<state::CommitStart> {
|
||||
}
|
||||
|
||||
/// Rejects the proposed protocol configuration.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `prover_io` - The IO to the prover.
|
||||
/// * `msg` - The optional rejection message.
|
||||
pub async fn reject<S: AsyncWrite + AsyncRead + Send + Unpin>(
|
||||
mut self,
|
||||
prover_io: S,
|
||||
msg: Option<&str>,
|
||||
) -> Result<(), VerifierError> {
|
||||
let mut duplex = self
|
||||
.state
|
||||
.prover_io
|
||||
.take()
|
||||
.expect("duplex should be available");
|
||||
|
||||
let fut = self.reject_inner(msg);
|
||||
let copy = CopyIo::new(prover_io, &mut duplex).map_err(VerifierError::from);
|
||||
|
||||
futures::try_join!(fut, copy)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "info", skip_all, err)]
|
||||
pub async fn reject(self, msg: Option<&str>) -> Result<(), VerifierError> {
|
||||
async fn reject_inner(self, msg: Option<&str>) -> Result<(), VerifierError> {
|
||||
let state::CommitStart {
|
||||
mux_ctrl,
|
||||
mut mux_fut,
|
||||
@@ -186,9 +258,31 @@ impl Verifier<state::CommitStart> {
|
||||
|
||||
impl Verifier<state::CommitAccepted> {
|
||||
/// Runs the verifier until the TLS connection is closed.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `prover_io` - The IO to the prover.
|
||||
pub async fn run<S: AsyncWrite + AsyncRead + Send + Unpin>(
|
||||
mut self,
|
||||
prover_io: S,
|
||||
) -> Result<Verifier<state::Committed>, VerifierError> {
|
||||
let mut duplex = self
|
||||
.state
|
||||
.prover_io
|
||||
.take()
|
||||
.expect("duplex should be available");
|
||||
|
||||
let fut = Box::pin(self.run_inner().fuse());
|
||||
let mut verifier = await_with_copy_io(fut, prover_io, &mut duplex).await?;
|
||||
|
||||
verifier.state.prover_io = Some(duplex);
|
||||
Ok(verifier)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "info", skip_all, err)]
|
||||
pub async fn run(self) -> Result<Verifier<state::Committed>, VerifierError> {
|
||||
async fn run_inner(self) -> Result<Verifier<state::Committed>, VerifierError> {
|
||||
let state::CommitAccepted {
|
||||
prover_io,
|
||||
mux_ctrl,
|
||||
mut mux_fut,
|
||||
mpc_tls,
|
||||
@@ -232,6 +326,7 @@ impl Verifier<state::CommitAccepted> {
|
||||
)
|
||||
.map_err(VerifierError::zk)?;
|
||||
|
||||
debug!("verifying tags");
|
||||
mux_fut
|
||||
.poll_with(vm.execute_all(&mut ctx).map_err(VerifierError::zk))
|
||||
.await?;
|
||||
@@ -241,10 +336,12 @@ impl Verifier<state::CommitAccepted> {
|
||||
// authenticated from the verifier's perspective.
|
||||
tag_proof.verify().map_err(VerifierError::zk)?;
|
||||
|
||||
debug!("MPC-TLS done");
|
||||
Ok(Verifier {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Committed {
|
||||
prover_io,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
ctx,
|
||||
@@ -263,9 +360,31 @@ impl Verifier<state::Committed> {
|
||||
}
|
||||
|
||||
/// Begins verification of statements from the prover.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `prover_io` - The IO to the prover.
|
||||
pub async fn verify<S: AsyncWrite + AsyncRead + Send + Unpin>(
|
||||
mut self,
|
||||
prover_io: S,
|
||||
) -> Result<Verifier<state::Verify>, VerifierError> {
|
||||
let mut duplex = self
|
||||
.state
|
||||
.prover_io
|
||||
.take()
|
||||
.expect("duplex should be available");
|
||||
|
||||
let fut = Box::pin(self.verify_inner().fuse());
|
||||
let mut verifier = await_with_copy_io(fut, prover_io, &mut duplex).await?;
|
||||
|
||||
verifier.state.prover_io = Some(duplex);
|
||||
Ok(verifier)
|
||||
}
|
||||
|
||||
#[instrument(parent = &self.span, level = "info", skip_all, err)]
|
||||
pub async fn verify(self) -> Result<Verifier<state::Verify>, VerifierError> {
|
||||
async fn verify_inner(self) -> Result<Verifier<state::Verify>, VerifierError> {
|
||||
let state::Committed {
|
||||
prover_io,
|
||||
mux_ctrl,
|
||||
mut mux_fut,
|
||||
mut ctx,
|
||||
@@ -286,6 +405,7 @@ impl Verifier<state::Committed> {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Verify {
|
||||
prover_io,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
ctx,
|
||||
@@ -300,18 +420,36 @@ impl Verifier<state::Committed> {
|
||||
}
|
||||
|
||||
/// Closes the connection with the prover.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `prover_io` - The IO to the prover.
|
||||
#[instrument(parent = &self.span, level = "info", skip_all, err)]
|
||||
pub async fn close(self) -> Result<(), VerifierError> {
|
||||
let state::Committed {
|
||||
mux_ctrl, mux_fut, ..
|
||||
} = self.state;
|
||||
pub async fn close<S: AsyncWrite + AsyncRead + Send + Unpin>(
|
||||
mut self,
|
||||
mut prover_io: S,
|
||||
) -> Result<(), VerifierError> {
|
||||
let state::Committed { mux_fut, .. } = self.state;
|
||||
|
||||
// Wait for the prover to correctly close the connection.
|
||||
if !mux_fut.is_complete() {
|
||||
mux_ctrl.close();
|
||||
mux_fut.await?;
|
||||
}
|
||||
let mut duplex = self
|
||||
.state
|
||||
.prover_io
|
||||
.take()
|
||||
.expect("duplex should be available");
|
||||
|
||||
duplex.close().await?;
|
||||
|
||||
let fut: Box<dyn Future<Output = Result<(), VerifierError>> + Send + Unpin> =
|
||||
if mux_fut.is_complete() {
|
||||
Box::new(futures::future::ready(Ok::<_, VerifierError>(())))
|
||||
} else {
|
||||
Box::new(mux_fut.map_err(VerifierError::from))
|
||||
};
|
||||
|
||||
let copy = CopyIo::new(&mut prover_io, &mut duplex).map_err(VerifierError::from);
|
||||
futures::try_join!(fut, copy)?;
|
||||
|
||||
prover_io.write_all(b"close").await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -323,10 +461,32 @@ impl Verifier<state::Verify> {
|
||||
}
|
||||
|
||||
/// Accepts the proving request.
|
||||
pub async fn accept(
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `prover_io` - The IO to the prover.
|
||||
pub async fn accept<S: AsyncWrite + AsyncRead + Send + Unpin>(
|
||||
mut self,
|
||||
prover_io: S,
|
||||
) -> Result<(VerifierOutput, Verifier<state::Committed>), VerifierError> {
|
||||
let mut duplex = self
|
||||
.state
|
||||
.prover_io
|
||||
.take()
|
||||
.expect("duplex should be available");
|
||||
|
||||
let fut = Box::pin(self.accept_inner().fuse());
|
||||
let (output, mut verifier) = await_with_copy_io(fut, prover_io, &mut duplex).await?;
|
||||
|
||||
verifier.state.prover_io = Some(duplex);
|
||||
Ok((output, verifier))
|
||||
}
|
||||
|
||||
async fn accept_inner(
|
||||
self,
|
||||
) -> Result<(VerifierOutput, Verifier<state::Committed>), VerifierError> {
|
||||
let state::Verify {
|
||||
prover_io,
|
||||
mux_ctrl,
|
||||
mut mux_fut,
|
||||
mut ctx,
|
||||
@@ -362,6 +522,7 @@ impl Verifier<state::Verify> {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Committed {
|
||||
prover_io,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
ctx,
|
||||
@@ -374,11 +535,35 @@ impl Verifier<state::Verify> {
|
||||
}
|
||||
|
||||
/// Rejects the proving request.
|
||||
pub async fn reject(
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `prover_io` - The IO to the prover.
|
||||
/// * `msg` - The optional rejection message.
|
||||
pub async fn reject<S: AsyncWrite + AsyncRead + Send + Unpin>(
|
||||
mut self,
|
||||
prover_io: S,
|
||||
msg: Option<&str>,
|
||||
) -> Result<Verifier<state::Committed>, VerifierError> {
|
||||
let mut duplex = self
|
||||
.state
|
||||
.prover_io
|
||||
.take()
|
||||
.expect("duplex should be available");
|
||||
|
||||
let fut = Box::pin(self.reject_inner(msg).fuse());
|
||||
let mut verifier = await_with_copy_io(fut, prover_io, &mut duplex).await?;
|
||||
|
||||
verifier.state.prover_io = Some(duplex);
|
||||
Ok(verifier)
|
||||
}
|
||||
|
||||
async fn reject_inner(
|
||||
self,
|
||||
msg: Option<&str>,
|
||||
) -> Result<Verifier<state::Committed>, VerifierError> {
|
||||
let state::Verify {
|
||||
prover_io,
|
||||
mux_ctrl,
|
||||
mut mux_fut,
|
||||
mut ctx,
|
||||
@@ -396,6 +581,7 @@ impl Verifier<state::Verify> {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Committed {
|
||||
prover_io,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
ctx,
|
||||
|
||||
@@ -2,6 +2,8 @@ use std::{error::Error, fmt};
|
||||
|
||||
use mpc_tls::MpcTlsError;
|
||||
|
||||
use crate::transcript_internal::commit::encoding::EncodingError;
|
||||
|
||||
/// Error for [`Verifier`](crate::verifier::Verifier).
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub struct VerifierError {
|
||||
@@ -55,6 +57,7 @@ enum ErrorKind {
|
||||
Config,
|
||||
Mpc,
|
||||
Zk,
|
||||
Commit,
|
||||
Verify,
|
||||
}
|
||||
|
||||
@@ -67,6 +70,7 @@ impl fmt::Display for VerifierError {
|
||||
ErrorKind::Config => f.write_str("config error")?,
|
||||
ErrorKind::Mpc => f.write_str("mpc error")?,
|
||||
ErrorKind::Zk => f.write_str("zk error")?,
|
||||
ErrorKind::Commit => f.write_str("commit error")?,
|
||||
ErrorKind::Verify => f.write_str("verification error")?,
|
||||
}
|
||||
|
||||
@@ -101,3 +105,9 @@ impl From<MpcTlsError> for VerifierError {
|
||||
Self::new(ErrorKind::Mpc, e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EncodingError> for VerifierError {
|
||||
fn from(e: EncodingError) -> Self {
|
||||
Self::new(ErrorKind::Commit, e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::mux::{MuxControl, MuxFuture};
|
||||
use futures_plex::DuplexStream;
|
||||
use mpc_tls::{MpcTlsFollower, SessionKeys};
|
||||
use mpz_common::Context;
|
||||
use tlsn_core::{
|
||||
@@ -25,6 +26,7 @@ opaque_debug::implement!(Initialized);
|
||||
|
||||
/// State after receiving protocol configuration from the prover.
|
||||
pub struct CommitStart {
|
||||
pub(crate) prover_io: Option<DuplexStream>,
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) ctx: Context,
|
||||
@@ -36,6 +38,7 @@ opaque_debug::implement!(CommitStart);
|
||||
/// State after accepting the proposed TLS commitment protocol configuration and
|
||||
/// performing preprocessing.
|
||||
pub struct CommitAccepted {
|
||||
pub(crate) prover_io: Option<DuplexStream>,
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) mpc_tls: MpcTlsFollower,
|
||||
@@ -47,6 +50,7 @@ opaque_debug::implement!(CommitAccepted);
|
||||
|
||||
/// State after the TLS transcript has been committed.
|
||||
pub struct Committed {
|
||||
pub(crate) prover_io: Option<DuplexStream>,
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) ctx: Context,
|
||||
@@ -59,6 +63,7 @@ opaque_debug::implement!(Committed);
|
||||
|
||||
/// State after receiving a proving request.
|
||||
pub struct Verify {
|
||||
pub(crate) prover_io: Option<DuplexStream>,
|
||||
pub(crate) mux_ctrl: MuxControl,
|
||||
pub(crate) mux_fut: MuxFuture,
|
||||
pub(crate) ctx: Context,
|
||||
|
||||
@@ -14,12 +14,19 @@ use tlsn_core::{
|
||||
};
|
||||
|
||||
use crate::{
|
||||
transcript_internal::{TranscriptRefs, auth::verify_plaintext, commit::hash::verify_hash},
|
||||
transcript_internal::{
|
||||
TranscriptRefs,
|
||||
auth::verify_plaintext,
|
||||
commit::{
|
||||
encoding::{self, KeyStore},
|
||||
hash::verify_hash,
|
||||
},
|
||||
},
|
||||
verifier::VerifierError,
|
||||
};
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn verify<T: Vm<Binary> + Send + Sync>(
|
||||
pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
|
||||
ctx: &mut Context,
|
||||
vm: &mut T,
|
||||
keys: &SessionKeys,
|
||||
@@ -87,6 +94,11 @@ pub(crate) async fn verify<T: Vm<Binary> + Send + Sync>(
|
||||
Direction::Sent => commit_sent.union_mut(idx),
|
||||
Direction::Received => commit_recv.union_mut(idx),
|
||||
});
|
||||
|
||||
if let Some((sent, recv)) = commit_config.encoding() {
|
||||
commit_sent.union_mut(sent);
|
||||
commit_recv.union_mut(recv);
|
||||
}
|
||||
}
|
||||
|
||||
let (sent_refs, sent_proof) = verify_plaintext(
|
||||
@@ -139,6 +151,24 @@ pub(crate) async fn verify<T: Vm<Binary> + Send + Sync>(
|
||||
sent_proof.verify().map_err(VerifierError::verify)?;
|
||||
recv_proof.verify().map_err(VerifierError::verify)?;
|
||||
|
||||
let mut encoder_secret = None;
|
||||
if let Some(commit_config) = request.transcript_commit()
|
||||
&& let Some((sent, recv)) = commit_config.encoding()
|
||||
{
|
||||
let sent_map = transcript_refs
|
||||
.sent
|
||||
.index(sent)
|
||||
.expect("ranges were authenticated");
|
||||
let recv_map = transcript_refs
|
||||
.recv
|
||||
.index(recv)
|
||||
.expect("ranges were authenticated");
|
||||
|
||||
let (secret, commitment) = encoding::transfer(ctx, vm, &sent_map, &recv_map).await?;
|
||||
encoder_secret = Some(secret);
|
||||
transcript_commitments.push(TranscriptCommitment::Encoding(commitment));
|
||||
}
|
||||
|
||||
if let Some(hash_commitments) = hash_commitments {
|
||||
for commitment in hash_commitments.try_recv().map_err(VerifierError::verify)? {
|
||||
transcript_commitments.push(TranscriptCommitment::Hash(commitment));
|
||||
@@ -148,6 +178,7 @@ pub(crate) async fn verify<T: Vm<Binary> + Send + Sync>(
|
||||
Ok(VerifierOutput {
|
||||
server_name,
|
||||
transcript: request.reveal().is_some().then_some(transcript),
|
||||
encoder_secret,
|
||||
transcript_commitments,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
||||
use rangeset::set::RangeSet;
|
||||
use tlsn::{
|
||||
config::{
|
||||
prove::ProveConfig,
|
||||
@@ -8,9 +9,12 @@ use tlsn::{
|
||||
verifier::VerifierConfig,
|
||||
},
|
||||
connection::ServerName,
|
||||
hash::HashAlgId,
|
||||
hash::{HashAlgId, HashProvider},
|
||||
prover::Prover,
|
||||
transcript::{Direction, Transcript, TranscriptCommitConfig, TranscriptCommitmentKind},
|
||||
transcript::{
|
||||
Direction, Transcript, TranscriptCommitConfig, TranscriptCommitment,
|
||||
TranscriptCommitmentKind, TranscriptSecret,
|
||||
},
|
||||
verifier::{Verifier, VerifierOutput},
|
||||
webpki::{CertificateDer, RootCertStore},
|
||||
};
|
||||
@@ -38,7 +42,7 @@ async fn test() {
|
||||
|
||||
let (socket_0, socket_1) = tokio::io::duplex(2 << 23);
|
||||
|
||||
let ((_full_transcript, _prover_output), verifier_output) =
|
||||
let ((full_transcript, prover_output), verifier_output) =
|
||||
tokio::join!(prover(socket_0), verifier(socket_1));
|
||||
|
||||
let partial_transcript = verifier_output.transcript.unwrap();
|
||||
@@ -54,6 +58,50 @@ async fn test() {
|
||||
partial_transcript.received_authed().iter().next().unwrap(),
|
||||
0..10
|
||||
);
|
||||
|
||||
let encoding_tree = prover_output
|
||||
.transcript_secrets
|
||||
.iter()
|
||||
.find_map(|secret| {
|
||||
if let TranscriptSecret::Encoding(tree) = secret {
|
||||
Some(tree)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let encoding_commitment = prover_output
|
||||
.transcript_commitments
|
||||
.iter()
|
||||
.find_map(|commitment| {
|
||||
if let TranscriptCommitment::Encoding(commitment) = commitment {
|
||||
Some(commitment)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let prove_sent = RangeSet::from(1..full_transcript.sent().len() - 1);
|
||||
let prove_recv = RangeSet::from(1..full_transcript.received().len() - 1);
|
||||
let idxs = [
|
||||
(Direction::Sent, prove_sent.clone()),
|
||||
(Direction::Received, prove_recv.clone()),
|
||||
];
|
||||
let proof = encoding_tree.proof(idxs.iter()).unwrap();
|
||||
let (auth_sent, auth_recv) = proof
|
||||
.verify_with_provider(
|
||||
&HashProvider::default(),
|
||||
&verifier_output.encoder_secret.unwrap(),
|
||||
encoding_commitment,
|
||||
full_transcript.sent(),
|
||||
full_transcript.received(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(auth_sent, prove_sent);
|
||||
assert_eq!(auth_recv, prove_recv);
|
||||
}
|
||||
|
||||
#[instrument(skip(verifier_socket))]
|
||||
@@ -62,7 +110,11 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
) -> (Transcript, ProverOutput) {
|
||||
let (client_socket, server_socket) = tokio::io::duplex(2 << 16);
|
||||
|
||||
let server_task = tokio::spawn(bind(server_socket.compat()));
|
||||
let client_socket = client_socket.compat();
|
||||
let server_socket = server_socket.compat();
|
||||
let mut verifier_socket = verifier_socket.compat();
|
||||
|
||||
let server_task = tokio::spawn(bind(server_socket));
|
||||
|
||||
let prover = Prover::new(ProverConfig::builder().build().unwrap())
|
||||
.commit(
|
||||
@@ -78,13 +130,13 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
)
|
||||
.build()
|
||||
.unwrap(),
|
||||
verifier_socket.compat(),
|
||||
&mut verifier_socket,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (mut tls_connection, prover_fut) = prover
|
||||
.connect(
|
||||
let (mut tls_connection, prover) = prover
|
||||
.setup(
|
||||
TlsClientConfig::builder()
|
||||
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
|
||||
.root_store(RootCertStore {
|
||||
@@ -92,44 +144,47 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
})
|
||||
.build()
|
||||
.unwrap(),
|
||||
client_socket.compat(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let prover_task = tokio::spawn(prover_fut);
|
||||
|
||||
let prover_task = tokio::spawn(prover.run(client_socket, verifier_socket));
|
||||
|
||||
tls_connection
|
||||
.write_all(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
tls_connection.close().await.unwrap();
|
||||
|
||||
let mut response = vec![0u8; 1024];
|
||||
tls_connection.read_to_end(&mut response).await.unwrap();
|
||||
|
||||
tls_connection.close().await.unwrap();
|
||||
let _ = server_task.await.unwrap();
|
||||
|
||||
let mut prover = prover_task.await.unwrap().unwrap();
|
||||
let (mut prover, _, mut verifier_socket) = prover_task.await.unwrap().unwrap();
|
||||
let sent_tx_len = prover.transcript().sent().len();
|
||||
let recv_tx_len = prover.transcript().received().len();
|
||||
|
||||
let mut builder = TranscriptCommitConfig::builder(prover.transcript());
|
||||
|
||||
let kind = TranscriptCommitmentKind::Hash {
|
||||
alg: HashAlgId::SHA256,
|
||||
};
|
||||
builder
|
||||
.commit_with_kind(&(0..sent_tx_len), Direction::Sent, kind)
|
||||
.unwrap();
|
||||
builder
|
||||
.commit_with_kind(&(0..recv_tx_len), Direction::Received, kind)
|
||||
.unwrap();
|
||||
builder
|
||||
.commit_with_kind(&(1..sent_tx_len - 1), Direction::Sent, kind)
|
||||
.unwrap();
|
||||
builder
|
||||
.commit_with_kind(&(1..recv_tx_len - 1), Direction::Received, kind)
|
||||
.unwrap();
|
||||
for kind in [
|
||||
TranscriptCommitmentKind::Encoding,
|
||||
TranscriptCommitmentKind::Hash {
|
||||
alg: HashAlgId::SHA256,
|
||||
},
|
||||
] {
|
||||
builder
|
||||
.commit_with_kind(&(0..sent_tx_len), Direction::Sent, kind)
|
||||
.unwrap();
|
||||
builder
|
||||
.commit_with_kind(&(0..recv_tx_len), Direction::Received, kind)
|
||||
.unwrap();
|
||||
builder
|
||||
.commit_with_kind(&(1..sent_tx_len - 1), Direction::Sent, kind)
|
||||
.unwrap();
|
||||
builder
|
||||
.commit_with_kind(&(1..recv_tx_len - 1), Direction::Received, kind)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let transcript_commit = builder.build().unwrap();
|
||||
|
||||
@@ -144,8 +199,8 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
|
||||
let config = builder.build().unwrap();
|
||||
let transcript = prover.transcript().clone();
|
||||
let output = prover.prove(&config).await.unwrap();
|
||||
prover.close().await.unwrap();
|
||||
let output = prover.prove(&config, &mut verifier_socket).await.unwrap();
|
||||
prover.close(&mut verifier_socket).await.unwrap();
|
||||
|
||||
(transcript, output)
|
||||
}
|
||||
@@ -154,6 +209,8 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
socket: T,
|
||||
) -> VerifierOutput {
|
||||
let mut socket = socket.compat();
|
||||
|
||||
let verifier = Verifier::new(
|
||||
VerifierConfig::builder()
|
||||
.root_store(RootCertStore {
|
||||
@@ -164,18 +221,24 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
);
|
||||
|
||||
let verifier = verifier
|
||||
.commit(socket.compat())
|
||||
.commit(&mut socket)
|
||||
.await
|
||||
.unwrap()
|
||||
.accept()
|
||||
.accept(&mut socket)
|
||||
.await
|
||||
.unwrap()
|
||||
.run()
|
||||
.run(&mut socket)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (output, verifier) = verifier.verify().await.unwrap().accept().await.unwrap();
|
||||
verifier.close().await.unwrap();
|
||||
let (output, verifier) = verifier
|
||||
.verify(&mut socket)
|
||||
.await
|
||||
.unwrap()
|
||||
.accept(&mut socket)
|
||||
.await
|
||||
.unwrap();
|
||||
verifier.close(&mut socket).await.unwrap();
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
@@ -23,9 +23,9 @@ no-bundler = ["web-spawn/no-bundler"]
|
||||
tlsn-core = { workspace = true }
|
||||
tlsn = { workspace = true, features = ["web", "mozilla-certs"] }
|
||||
tlsn-server-fixture-certs = { workspace = true }
|
||||
tlsn-tls-client-async = { workspace = true }
|
||||
tlsn-tls-core = { workspace = true }
|
||||
|
||||
async_io_stream = { version = "0.3" }
|
||||
bincode = { workspace = true }
|
||||
console_error_panic_hook = { version = "0.1" }
|
||||
enum-try-as-inner = { workspace = true }
|
||||
|
||||
@@ -2,11 +2,11 @@ mod config;
|
||||
|
||||
pub use config::ProverConfig;
|
||||
|
||||
use async_io_stream::IoStream;
|
||||
use enum_try_as_inner::EnumTryAsInner;
|
||||
use futures::TryFutureExt;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use hyper::body::Bytes;
|
||||
use tls_client_async::TlsConnection;
|
||||
use tlsn::{
|
||||
config::{
|
||||
prove::ProveConfig,
|
||||
@@ -14,13 +14,13 @@ use tlsn::{
|
||||
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig},
|
||||
},
|
||||
connection::ServerName,
|
||||
prover::{state, Prover},
|
||||
prover::{state, Prover, TlsConnection},
|
||||
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
|
||||
};
|
||||
use tracing::info;
|
||||
use wasm_bindgen::{prelude::*, JsError};
|
||||
use wasm_bindgen_futures::spawn_local;
|
||||
use ws_stream_wasm::WsMeta;
|
||||
use ws_stream_wasm::{WsMeta, WsStreamIo};
|
||||
|
||||
use crate::{io::FuturesIo, types::*};
|
||||
|
||||
@@ -36,8 +36,14 @@ pub struct JsProver {
|
||||
#[derive_err(Debug)]
|
||||
enum State {
|
||||
Initialized(Prover<state::Initialized>),
|
||||
CommitAccepted(Prover<state::CommitAccepted>),
|
||||
Committed(Prover<state::Committed>),
|
||||
CommitAccepted {
|
||||
prover: Prover<state::CommitAccepted>,
|
||||
verifier_conn: IoStream<WsStreamIo, Vec<u8>>,
|
||||
},
|
||||
Committed {
|
||||
prover: Prover<state::Committed>,
|
||||
verifier_conn: IoStream<WsStreamIo, Vec<u8>>,
|
||||
},
|
||||
Complete,
|
||||
Error,
|
||||
}
|
||||
@@ -96,12 +102,16 @@ impl JsProver {
|
||||
info!("connecting to verifier");
|
||||
|
||||
let (_, verifier_conn) = WsMeta::connect(verifier_url, None).await?;
|
||||
let mut verifier_conn = verifier_conn.into_io();
|
||||
|
||||
info!("connected to verifier");
|
||||
|
||||
let prover = prover.commit(config, verifier_conn.into_io()).await?;
|
||||
let prover = prover.commit(config, &mut verifier_conn).await?;
|
||||
|
||||
self.state = State::CommitAccepted(prover);
|
||||
self.state = State::CommitAccepted {
|
||||
prover,
|
||||
verifier_conn,
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -112,7 +122,7 @@ impl JsProver {
|
||||
ws_proxy_url: &str,
|
||||
request: HttpRequest,
|
||||
) -> Result<HttpResponse> {
|
||||
let prover = self.state.take().try_into_commit_accepted()?;
|
||||
let (prover, mut verifier_conn) = self.state.take().try_into_commit_accepted()?;
|
||||
|
||||
let mut builder = TlsClientConfig::builder()
|
||||
.server_name(ServerName::Dns(
|
||||
@@ -145,35 +155,41 @@ impl JsProver {
|
||||
info!("connecting to server");
|
||||
|
||||
let (_, server_conn) = WsMeta::connect(ws_proxy_url, None).await?;
|
||||
let mut server_conn = server_conn.into_io();
|
||||
|
||||
info!("connected to server");
|
||||
|
||||
let (tls_conn, prover_fut) = prover.connect(config, server_conn.into_io()).await?;
|
||||
let (tls_conn, prover) = prover.setup(config)?;
|
||||
let mut prover = prover.connect(&mut server_conn, &mut verifier_conn);
|
||||
|
||||
info!("sending request");
|
||||
|
||||
let (response, prover) = futures::try_join!(
|
||||
let (response, _) = futures::try_join!(
|
||||
send_request(tls_conn, request),
|
||||
prover_fut.map_err(Into::into)
|
||||
(&mut prover).map_err(Into::into)
|
||||
)?;
|
||||
let prover = prover.finish()?;
|
||||
|
||||
info!("response received");
|
||||
|
||||
self.state = State::Committed(prover);
|
||||
self.state = State::Committed {
|
||||
prover,
|
||||
verifier_conn,
|
||||
};
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Returns the transcript.
|
||||
pub fn transcript(&self) -> Result<Transcript> {
|
||||
let prover = self.state.try_as_committed()?;
|
||||
let (prover, _) = self.state.try_as_committed()?;
|
||||
|
||||
Ok(Transcript::from(prover.transcript()))
|
||||
}
|
||||
|
||||
/// Reveals data to the verifier and finalizes the protocol.
|
||||
pub async fn reveal(&mut self, reveal: Reveal) -> Result<()> {
|
||||
let mut prover = self.state.take().try_into_committed()?;
|
||||
let (mut prover, mut verifier_conn) = self.state.take().try_into_committed()?;
|
||||
|
||||
info!("revealing data");
|
||||
|
||||
@@ -193,8 +209,8 @@ impl JsProver {
|
||||
|
||||
let config = builder.build()?;
|
||||
|
||||
prover.prove(&config).await?;
|
||||
prover.close().await?;
|
||||
prover.prove(&config, &mut verifier_conn).await?;
|
||||
prover.close(&mut verifier_conn).await?;
|
||||
|
||||
info!("Finalized");
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ mod config;
|
||||
|
||||
pub use config::VerifierConfig;
|
||||
|
||||
use async_io_stream::IoStream;
|
||||
use enum_try_as_inner::EnumTryAsInner;
|
||||
use tlsn::{
|
||||
config::tls_commit::TlsCommitProtocolConfig,
|
||||
@@ -12,7 +13,7 @@ use tlsn::{
|
||||
};
|
||||
use tracing::info;
|
||||
use wasm_bindgen::prelude::*;
|
||||
use ws_stream_wasm::{WsMeta, WsStream};
|
||||
use ws_stream_wasm::{WsMeta, WsStreamIo};
|
||||
|
||||
use crate::types::VerifierOutput;
|
||||
|
||||
@@ -26,9 +27,13 @@ pub struct JsVerifier {
|
||||
|
||||
#[derive(EnumTryAsInner)]
|
||||
#[derive_err(Debug)]
|
||||
#[allow(unused_assignments)]
|
||||
enum State {
|
||||
Initialized(Verifier<state::Initialized>),
|
||||
Connected((Verifier<state::Initialized>, WsStream)),
|
||||
Connected {
|
||||
verifier: Verifier<state::Initialized>,
|
||||
prover_conn: IoStream<WsStreamIo, Vec<u8>>,
|
||||
},
|
||||
Complete,
|
||||
Error,
|
||||
}
|
||||
@@ -66,19 +71,23 @@ impl JsVerifier {
|
||||
info!("Connecting to prover");
|
||||
|
||||
let (_, prover_conn) = WsMeta::connect(prover_url, None).await?;
|
||||
let prover_conn = prover_conn.into_io();
|
||||
|
||||
info!("Connected to prover");
|
||||
|
||||
self.state = State::Connected((verifier, prover_conn));
|
||||
self.state = State::Connected {
|
||||
verifier,
|
||||
prover_conn,
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Verifies the connection and finalizes the protocol.
|
||||
pub async fn verify(&mut self) -> Result<VerifierOutput> {
|
||||
let (verifier, prover_conn) = self.state.take().try_into_connected()?;
|
||||
let (verifier, mut prover_conn) = self.state.take().try_into_connected()?;
|
||||
|
||||
let verifier = verifier.commit(prover_conn.into_io()).await?;
|
||||
let verifier = verifier.commit(&mut prover_conn).await?;
|
||||
let request = verifier.request();
|
||||
|
||||
let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = request.protocol() else {
|
||||
@@ -98,11 +107,15 @@ impl JsVerifier {
|
||||
};
|
||||
|
||||
if reject.is_some() {
|
||||
verifier.reject(reject).await?;
|
||||
verifier.reject(&mut prover_conn, reject).await?;
|
||||
return Err(JsError::new("protocol configuration rejected"));
|
||||
}
|
||||
|
||||
let verifier = verifier.accept().await?.run().await?;
|
||||
let verifier = verifier
|
||||
.accept(&mut prover_conn)
|
||||
.await?
|
||||
.run(&mut prover_conn)
|
||||
.await?;
|
||||
|
||||
let sent = verifier
|
||||
.tls_transcript()
|
||||
@@ -129,8 +142,12 @@ impl JsVerifier {
|
||||
},
|
||||
};
|
||||
|
||||
let (output, verifier) = verifier.verify().await?.accept().await?;
|
||||
verifier.close().await?;
|
||||
let (output, verifier) = verifier
|
||||
.verify(&mut prover_conn)
|
||||
.await?
|
||||
.accept(&mut prover_conn)
|
||||
.await?;
|
||||
verifier.close(&mut prover_conn).await?;
|
||||
|
||||
self.state = State::Complete;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user