refactor: sans-io TLS IO and flexible socket API

This commit is contained in:
th4s
2026-01-07 08:14:04 -08:00
committed by sinu
parent 2101285f7f
commit 0a086ee91f
39 changed files with 1931 additions and 1626 deletions

128
Cargo.lock generated
View File

@@ -174,9 +174,9 @@ checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
[[package]] [[package]]
name = "alloy-consensus" name = "alloy-consensus"
version = "1.1.2" version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b6440213a22df93a87ed512d2f668e7dc1d62a05642d107f82d61edc9e12370" checksum = "2e318e25fb719e747a7e8db1654170fc185024f3ed5b10f86c08d448a912f6e2"
dependencies = [ dependencies = [
"alloy-eips", "alloy-eips",
"alloy-primitives", "alloy-primitives",
@@ -201,9 +201,9 @@ dependencies = [
[[package]] [[package]]
name = "alloy-consensus-any" name = "alloy-consensus-any"
version = "1.1.2" version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15d0bea09287942405c4f9d2a4f22d1e07611c2dbd9d5bf94b75366340f9e6e0" checksum = "364380a845193a317bcb7a5398fc86cdb66c47ebe010771dde05f6869bf9e64a"
dependencies = [ dependencies = [
"alloy-consensus", "alloy-consensus",
"alloy-eips", "alloy-eips",
@@ -253,9 +253,9 @@ dependencies = [
[[package]] [[package]]
name = "alloy-eips" name = "alloy-eips"
version = "1.1.2" version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4bd2c7ae05abcab4483ce821f12f285e01c0b33804e6883dd9ca1569a87ee2be" checksum = "a4c4d7c5839d9f3a467900c625416b24328450c65702eb3d8caff8813e4d1d33"
dependencies = [ dependencies = [
"alloy-eip2124", "alloy-eip2124",
"alloy-eip2930", "alloy-eip2930",
@@ -288,9 +288,9 @@ dependencies = [
[[package]] [[package]]
name = "alloy-json-rpc" name = "alloy-json-rpc"
version = "1.1.2" version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "003f46c54f22854a32b9cc7972660a476968008ad505427eabab49225309ec40" checksum = "f72cf87cda808e593381fb9f005ffa4d2475552b7a6c5ac33d087bf77d82abd0"
dependencies = [ dependencies = [
"alloy-primitives", "alloy-primitives",
"alloy-sol-types", "alloy-sol-types",
@@ -303,9 +303,9 @@ dependencies = [
[[package]] [[package]]
name = "alloy-network" name = "alloy-network"
version = "1.1.2" version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f4029954d9406a40979f3a3b46950928a0fdcfe3ea8a9b0c17490d57e8aa0e3" checksum = "12aeb37b6f2e61b93b1c3d34d01ee720207c76fe447e2a2c217e433ac75b17f5"
dependencies = [ dependencies = [
"alloy-consensus", "alloy-consensus",
"alloy-consensus-any", "alloy-consensus-any",
@@ -329,9 +329,9 @@ dependencies = [
[[package]] [[package]]
name = "alloy-network-primitives" name = "alloy-network-primitives"
version = "1.1.2" version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7805124ad69e57bbae7731c9c344571700b2a18d351bda9e0eba521c991d1bcb" checksum = "abd29ace62872083e30929cd9b282d82723196d196db589f3ceda67edcc05552"
dependencies = [ dependencies = [
"alloy-consensus", "alloy-consensus",
"alloy-eips", "alloy-eips",
@@ -391,9 +391,9 @@ dependencies = [
[[package]] [[package]]
name = "alloy-rpc-types-any" name = "alloy-rpc-types-any"
version = "1.1.2" version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b43c1622aac2508d528743fd4cfdac1dea92d5a8fa894038488ff7edd0af0b32" checksum = "6a63fb40ed24e4c92505f488f9dd256e2afaed17faa1b7a221086ebba74f4122"
dependencies = [ dependencies = [
"alloy-consensus-any", "alloy-consensus-any",
"alloy-rpc-types-eth", "alloy-rpc-types-eth",
@@ -402,9 +402,9 @@ dependencies = [
[[package]] [[package]]
name = "alloy-rpc-types-eth" name = "alloy-rpc-types-eth"
version = "1.1.2" version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed5fafb741c19b3cca4cdd04fa215c89413491f9695a3e928dee2ae5657f607e" checksum = "9eae0c7c40da20684548cbc8577b6b7447f7bf4ddbac363df95e3da220e41e72"
dependencies = [ dependencies = [
"alloy-consensus", "alloy-consensus",
"alloy-consensus-any", "alloy-consensus-any",
@@ -423,9 +423,9 @@ dependencies = [
[[package]] [[package]]
name = "alloy-serde" name = "alloy-serde"
version = "1.1.2" version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6f180c399ca7c1e2fe17ea58343910cad0090878a696ff5a50241aee12fc529" checksum = "c0df1987ed0ff2d0159d76b52e7ddfc4e4fbddacc54d2fbee765e0d14d7c01b5"
dependencies = [ dependencies = [
"alloy-primitives", "alloy-primitives",
"serde", "serde",
@@ -434,9 +434,9 @@ dependencies = [
[[package]] [[package]]
name = "alloy-signer" name = "alloy-signer"
version = "1.1.2" version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ecc39ad2c0a3d2da8891f4081565780703a593f090f768f884049aa3aa929cbc" checksum = "6ff69deedee7232d7ce5330259025b868c5e6a52fa8dffda2c861fb3a5889b24"
dependencies = [ dependencies = [
"alloy-primitives", "alloy-primitives",
"async-trait", "async-trait",
@@ -449,9 +449,9 @@ dependencies = [
[[package]] [[package]]
name = "alloy-signer-local" name = "alloy-signer-local"
version = "1.1.2" version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "930e17cb1e46446a193a593a3bfff8d0ecee4e510b802575ebe300ae2e43ef75" checksum = "72cfe0be3ec5a8c1a46b2e5a7047ed41121d360d97f4405bb7c1c784880c86cb"
dependencies = [ dependencies = [
"alloy-consensus", "alloy-consensus",
"alloy-network", "alloy-network",
@@ -551,9 +551,9 @@ dependencies = [
[[package]] [[package]]
name = "alloy-tx-macros" name = "alloy-tx-macros"
version = "1.1.2" version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae109e33814b49fc0a62f2528993aa8a2dd346c26959b151f05441dc0b9da292" checksum = "333544408503f42d7d3792bfc0f7218b643d968a03d2c0ed383ae558fb4a76d0"
dependencies = [ dependencies = [
"darling 0.21.3", "darling 0.21.3",
"proc-macro2", "proc-macro2",
@@ -1783,9 +1783,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.2.48" version = "1.2.49"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c481bdbf0ed3b892f6f806287d72acd515b352a4ec27a208489b8c1bc839633a" checksum = "90583009037521a116abf44494efecd645ba48b6622457080f080b85544e2215"
dependencies = [ dependencies = [
"find-msvc-tools", "find-msvc-tools",
"shlex", "shlex",
@@ -3220,6 +3220,14 @@ dependencies = [
"syn 2.0.111", "syn 2.0.111",
] ]
[[package]]
name = "futures-plex"
version = "0.1.0"
source = "git+https://github.com/tlsnotary/tlsn-utils?rev=c210f2f#c210f2fdd0a5d71c3e217fa03127c9f616314836"
dependencies = [
"futures",
]
[[package]] [[package]]
name = "futures-rustls" name = "futures-rustls"
version = "0.25.1" version = "0.25.1"
@@ -3766,9 +3774,9 @@ checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a"
[[package]] [[package]]
name = "icu_properties" name = "icu_properties"
version = "2.1.1" version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99" checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec"
dependencies = [ dependencies = [
"icu_collections", "icu_collections",
"icu_locale_core", "icu_locale_core",
@@ -3780,9 +3788,9 @@ dependencies = [
[[package]] [[package]]
name = "icu_properties_data" name = "icu_properties_data"
version = "2.1.1" version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899" checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af"
[[package]] [[package]]
name = "icu_provider" name = "icu_provider"
@@ -4302,9 +4310,9 @@ dependencies = [
[[package]] [[package]]
name = "mio" name = "mio"
version = "1.1.0" version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69d83b0086dc8ecf3ce9ae2874b2d1290252e2a30720bea58a5c6639b0092873" checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc"
dependencies = [ dependencies = [
"libc", "libc",
"wasi", "wasi",
@@ -5486,7 +5494,7 @@ version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "219cb19e96be00ab2e37d6e299658a0cfa83e52429179969b0f0121b4ac46983" checksum = "219cb19e96be00ab2e37d6e299658a0cfa83e52429179969b0f0121b4ac46983"
dependencies = [ dependencies = [
"toml_edit 0.23.7", "toml_edit 0.23.9",
] ]
[[package]] [[package]]
@@ -5901,9 +5909,9 @@ checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58"
[[package]] [[package]]
name = "reqwest" name = "reqwest"
version = "0.12.24" version = "0.12.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d0946410b9f7b082a427e4ef5c8ff541a88b357bc6c637c40db3a68ac70a36f" checksum = "b6eff9328d40131d43bd911d42d79eb6a47312002a4daefc9e37f17e74a7701a"
dependencies = [ dependencies = [
"base64 0.22.1", "base64 0.22.1",
"bytes", "bytes",
@@ -5930,7 +5938,7 @@ dependencies = [
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"tower", "tower",
"tower-http 0.6.7", "tower-http 0.6.8",
"tower-service", "tower-service",
"url", "url",
"wasm-bindgen", "wasm-bindgen",
@@ -6788,9 +6796,9 @@ dependencies = [
[[package]] [[package]]
name = "simd-adler32" name = "simd-adler32"
version = "0.3.7" version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2"
[[package]] [[package]]
name = "sized-chunks" name = "sized-chunks"
@@ -6854,9 +6862,9 @@ checksum = "bceb57dc07c92cdae60f5b27b3fa92ecaaa42fe36c55e22dbfb0b44893e0b1f7"
[[package]] [[package]]
name = "sourcemap" name = "sourcemap"
version = "9.2.2" version = "9.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e22afbcb92ce02d23815b9795523c005cb9d3c214f8b7a66318541c240ea7935" checksum = "37ccaaa78a0ca68b20f8f711eaa2522a00131c48a3de5b892ca5c36cec1ce9bb"
dependencies = [ dependencies = [
"base64-simd", "base64-simd",
"bitvec", "bitvec",
@@ -7032,9 +7040,9 @@ dependencies = [
[[package]] [[package]]
name = "sys_traits" name = "sys_traits"
version = "0.1.19" version = "0.1.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1495a604cd38eeb30c408724966cd31ca1b68b5a97e3afc474c0d719bfeec5a" checksum = "6b61f4a25d0baba25511bed00c39c199d9a19cfd8107f4472724b72a84f530b1"
dependencies = [ dependencies = [
"sys_traits_macros", "sys_traits_macros",
] ]
@@ -7255,6 +7263,7 @@ dependencies = [
"aes 0.8.4", "aes 0.8.4",
"ctr 0.9.2", "ctr 0.9.2",
"futures", "futures",
"futures-plex",
"ghash 0.5.1", "ghash 0.5.1",
"http-body-util", "http-body-util",
"hyper", "hyper",
@@ -7272,6 +7281,7 @@ dependencies = [
"mpz-zk", "mpz-zk",
"once_cell", "once_cell",
"opaque-debug", "opaque-debug",
"pin-project-lite",
"rand 0.9.2", "rand 0.9.2",
"rangeset 0.4.0", "rangeset 0.4.0",
"rstest", "rstest",
@@ -7289,7 +7299,6 @@ dependencies = [
"tlsn-server-fixture", "tlsn-server-fixture",
"tlsn-server-fixture-certs", "tlsn-server-fixture-certs",
"tlsn-tls-client", "tlsn-tls-client",
"tlsn-tls-client-async",
"tlsn-tls-core", "tlsn-tls-core",
"tokio", "tokio",
"tokio-util", "tokio-util",
@@ -7608,7 +7617,6 @@ dependencies = [
"tlsn-key-exchange", "tlsn-key-exchange",
"tlsn-tls-backend", "tlsn-tls-backend",
"tlsn-tls-client", "tlsn-tls-client",
"tlsn-tls-client-async",
"tlsn-tls-core", "tlsn-tls-core",
"tokio", "tokio",
"tokio-util", "tokio-util",
@@ -7635,7 +7643,7 @@ dependencies = [
"tokio", "tokio",
"tokio-util", "tokio-util",
"tower", "tower",
"tower-http 0.6.7", "tower-http 0.6.8",
"tower-service", "tower-service",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
@@ -7683,26 +7691,6 @@ dependencies = [
"webpki-roots 1.0.4", "webpki-roots 1.0.4",
] ]
[[package]]
name = "tlsn-tls-client-async"
version = "0.1.0-alpha.14-pre"
dependencies = [
"bytes",
"futures",
"http-body-util",
"hyper",
"hyper-util",
"rstest",
"rustls-pki-types",
"rustls-webpki 0.103.8",
"thiserror 1.0.69",
"tls-server-fixture",
"tlsn-tls-client",
"tokio",
"tokio-util",
"tracing",
]
[[package]] [[package]]
name = "tlsn-tls-core" name = "tlsn-tls-core"
version = "0.1.0-alpha.14-pre" version = "0.1.0-alpha.14-pre"
@@ -7731,6 +7719,7 @@ source = "git+https://github.com/tlsnotary/tlsn-utils?rev=6168663#6168663495281f
name = "tlsn-wasm" name = "tlsn-wasm"
version = "0.1.0-alpha.14-pre" version = "0.1.0-alpha.14-pre"
dependencies = [ dependencies = [
"async_io_stream",
"bincode 1.3.3", "bincode 1.3.3",
"console_error_panic_hook", "console_error_panic_hook",
"enum-try-as-inner", "enum-try-as-inner",
@@ -7749,7 +7738,6 @@ dependencies = [
"tlsn", "tlsn",
"tlsn-core", "tlsn-core",
"tlsn-server-fixture-certs", "tlsn-server-fixture-certs",
"tlsn-tls-client-async",
"tlsn-tls-core", "tlsn-tls-core",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
@@ -7896,9 +7884,9 @@ dependencies = [
[[package]] [[package]]
name = "toml_edit" name = "toml_edit"
version = "0.23.7" version = "0.23.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6485ef6d0d9b5d0ec17244ff7eb05310113c3f316f2d14200d4de56b3cb98f8d" checksum = "5d7cbc3b4b49633d57a0509303158ca50de80ae32c265093b24c414705807832"
dependencies = [ dependencies = [
"indexmap 2.12.1", "indexmap 2.12.1",
"toml_datetime 0.7.3", "toml_datetime 0.7.3",
@@ -7964,9 +7952,9 @@ dependencies = [
[[package]] [[package]]
name = "tower-http" name = "tower-http"
version = "0.6.7" version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cf146f99d442e8e68e585f5d798ccd3cad9a7835b917e09728880a862706456" checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8"
dependencies = [ dependencies = [
"bitflags", "bitflags",
"bytes", "bytes",

View File

@@ -13,7 +13,6 @@ members = [
"crates/server-fixture/server", "crates/server-fixture/server",
"crates/tls/backend", "crates/tls/backend",
"crates/tls/client", "crates/tls/client",
"crates/tls/client-async",
"crates/tls/core", "crates/tls/core",
"crates/mpc-tls", "crates/mpc-tls",
"crates/tls/server-fixture", "crates/tls/server-fixture",
@@ -57,7 +56,6 @@ tlsn-server-fixture = { path = "crates/server-fixture/server" }
tlsn-server-fixture-certs = { path = "crates/server-fixture/certs" } tlsn-server-fixture-certs = { path = "crates/server-fixture/certs" }
tlsn-tls-backend = { path = "crates/tls/backend" } tlsn-tls-backend = { path = "crates/tls/backend" }
tlsn-tls-client = { path = "crates/tls/client" } tlsn-tls-client = { path = "crates/tls/client" }
tlsn-tls-client-async = { path = "crates/tls/client-async" }
tlsn-tls-core = { path = "crates/tls/core" } tlsn-tls-core = { path = "crates/tls/core" }
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" } tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
tlsn-harness-core = { path = "crates/harness/core" } tlsn-harness-core = { path = "crates/harness/core" }
@@ -82,6 +80,7 @@ mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" } mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" } mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
futures-plex = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "c210f2f" }
rangeset = { version = "0.4" } rangeset = { version = "0.4" }
serio = { version = "0.2" } serio = { version = "0.2" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" } spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }

View File

@@ -87,13 +87,15 @@ async fn main() -> Result<()> {
} }
async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>( async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: S, verifier_socket: S,
req_tx: Sender<AttestationRequest>, req_tx: Sender<AttestationRequest>,
resp_rx: Receiver<Attestation>, resp_rx: Receiver<Attestation>,
uri: &str, uri: &str,
extra_headers: Vec<(&str, &str)>, extra_headers: Vec<(&str, &str)>,
example_type: &ExampleType, example_type: &ExampleType,
) -> Result<()> { ) -> Result<()> {
let mut verifier_socket = verifier_socket.compat();
let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into()); let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into());
let server_port: u16 = env::var("SERVER_PORT") let server_port: u16 = env::var("SERVER_PORT")
.map(|port| port.parse().expect("port should be valid integer")) .map(|port| port.parse().expect("port should be valid integer"))
@@ -115,37 +117,36 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.build()?, .build()?,
) )
.build()?, .build()?,
socket.compat(), &mut verifier_socket,
) )
.await?; .await?;
// Open a TCP connection to the server. // Open a TCP connection to the server.
let client_socket = tokio::net::TcpStream::connect((server_host, server_port)).await?; let client_socket = tokio::net::TcpStream::connect((server_host, server_port))
.await?
.compat();
// Bind the prover to the server connection. // Bind the prover to the server connection.
let (tls_connection, prover_fut) = prover let (tls_connection, prover) = prover.setup(
.connect( TlsClientConfig::builder()
TlsClientConfig::builder() .server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?)) // Create a root certificate store with the server-fixture's self-signed
// Create a root certificate store with the server-fixture's self-signed // certificate. This is only required for offline testing with the
// certificate. This is only required for offline testing with the // server-fixture.
// server-fixture. .root_store(RootCertStore {
.root_store(RootCertStore { roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
roots: vec![CertificateDer(CA_CERT_DER.to_vec())], })
}) // (Optional) Set up TLS client authentication if required by the server.
// (Optional) Set up TLS client authentication if required by the server. .client_auth((
.client_auth(( vec![CertificateDer(CLIENT_CERT_DER.to_vec())],
vec![CertificateDer(CLIENT_CERT_DER.to_vec())], PrivateKeyDer(CLIENT_KEY_DER.to_vec()),
PrivateKeyDer(CLIENT_KEY_DER.to_vec()), ))
)) .build()?,
.build()?, )?;
client_socket.compat(),
)
.await?;
let tls_connection = TokioIo::new(tls_connection.compat()); let tls_connection = TokioIo::new(tls_connection.compat());
// Spawn the prover task to be run concurrently in the background. // Spawn the prover task to be run concurrently in the background.
let prover_task = tokio::spawn(prover_fut); let prover_task = tokio::spawn(prover.run(client_socket, verifier_socket));
// Attach the hyper HTTP client to the connection. // Attach the hyper HTTP client to the connection.
let (mut request_sender, connection) = let (mut request_sender, connection) =
@@ -180,7 +181,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
assert!(response.status() == StatusCode::OK); assert!(response.status() == StatusCode::OK);
// The prover task should be done now, so we can await it. // The prover task should be done now, so we can await it.
let prover = prover_task.await??; let (prover, _, verifier_socket) = prover_task.await??;
// Parse the HTTP transcript. // Parse the HTTP transcript.
let transcript = HttpTranscript::parse(prover.transcript())?; let transcript = HttpTranscript::parse(prover.transcript())?;
@@ -222,7 +223,8 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
let request_config = builder.build()?; let request_config = builder.build()?;
let (attestation, secrets) = notarize(prover, &request_config, req_tx, resp_rx).await?; let (attestation, secrets) =
notarize(prover, &request_config, verifier_socket, req_tx, resp_rx).await?;
// Write the attestation to disk. // Write the attestation to disk.
let attestation_path = tlsn_examples::get_file_path(example_type, "attestation"); let attestation_path = tlsn_examples::get_file_path(example_type, "attestation");
@@ -242,9 +244,10 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
Ok(()) Ok(())
} }
async fn notarize( async fn notarize<S: futures::AsyncRead + futures::AsyncWrite + Send + Unpin>(
mut prover: Prover<Committed>, mut prover: Prover<Committed>,
config: &RequestConfig, config: &RequestConfig,
mut verifier_socket: S,
request_tx: Sender<AttestationRequest>, request_tx: Sender<AttestationRequest>,
attestation_rx: Receiver<Attestation>, attestation_rx: Receiver<Attestation>,
) -> Result<(Attestation, Secrets)> { ) -> Result<(Attestation, Secrets)> {
@@ -260,11 +263,13 @@ async fn notarize(
transcript_commitments, transcript_commitments,
transcript_secrets, transcript_secrets,
.. ..
} = prover.prove(&disclosure_config).await?; } = prover
.prove(&disclosure_config, &mut verifier_socket)
.await?;
let transcript = prover.transcript().clone(); let transcript = prover.transcript().clone();
let tls_transcript = prover.tls_transcript().clone(); let tls_transcript = prover.tls_transcript().clone();
prover.close().await?; prover.close(&mut verifier_socket).await?;
// Build an attestation request. // Build an attestation request.
let mut builder = AttestationRequest::builder(config); let mut builder = AttestationRequest::builder(config);
@@ -307,10 +312,12 @@ async fn notarize(
} }
async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>( async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: S, prover_socket: S,
request_rx: Receiver<AttestationRequest>, request_rx: Receiver<AttestationRequest>,
attestation_tx: Sender<Attestation>, attestation_tx: Sender<Attestation>,
) -> Result<()> { ) -> Result<()> {
let mut prover_socket = prover_socket.compat();
// Create a root certificate store with the server-fixture's self-signed // Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the // certificate. This is only required for offline testing with the
// server-fixture. // server-fixture.
@@ -322,11 +329,11 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.unwrap(); .unwrap();
let verifier = Verifier::new(verifier_config) let verifier = Verifier::new(verifier_config)
.commit(socket.compat()) .commit(&mut prover_socket)
.await? .await?
.accept() .accept(&mut prover_socket)
.await? .await?
.run() .run(&mut prover_socket)
.await?; .await?;
let ( let (
@@ -336,11 +343,15 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.. ..
}, },
verifier, verifier,
) = verifier.verify().await?.accept().await?; ) = verifier
.verify(&mut prover_socket)
.await?
.accept(&mut prover_socket)
.await?;
let tls_transcript = verifier.tls_transcript().clone(); let tls_transcript = verifier.tls_transcript().clone();
verifier.close().await?; verifier.close(&mut prover_socket).await?;
let sent_len = tls_transcript let sent_len = tls_transcript
.sent() .sent()

View File

@@ -73,6 +73,8 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
server_addr: &SocketAddr, server_addr: &SocketAddr,
uri: &str, uri: &str,
) -> Result<()> { ) -> Result<()> {
let mut verifier_socket = verifier_socket.compat();
let uri = uri.parse::<Uri>().unwrap(); let uri = uri.parse::<Uri>().unwrap();
assert_eq!(uri.scheme().unwrap().as_str(), "https"); assert_eq!(uri.scheme().unwrap().as_str(), "https");
let server_domain = uri.authority().unwrap().host(); let server_domain = uri.authority().unwrap().host();
@@ -93,32 +95,30 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
.build()?, .build()?,
) )
.build()?, .build()?,
verifier_socket.compat(), &mut verifier_socket,
) )
.await?; .await?;
// Open a TCP connection to the server. // Open a TCP connection to the server.
let client_socket = tokio::net::TcpStream::connect(server_addr).await?; let client_socket = tokio::net::TcpStream::connect(server_addr).await?.compat();
// Bind the prover to the server connection. // Bind the prover to the server connection.
let (tls_connection, prover_fut) = prover let (tls_connection, prover) = prover.setup(
.connect( TlsClientConfig::builder()
TlsClientConfig::builder() .server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?)) // Create a root certificate store with the server-fixture's self-signed
// Create a root certificate store with the server-fixture's self-signed // certificate. This is only required for offline testing with the
// certificate. This is only required for offline testing with the // server-fixture.
// server-fixture. .root_store(RootCertStore {
.root_store(RootCertStore { roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
roots: vec![CertificateDer(CA_CERT_DER.to_vec())], })
}) .build()?,
.build()?, )?;
client_socket.compat(),
)
.await?;
let tls_connection = TokioIo::new(tls_connection.compat()); let tls_connection = TokioIo::new(tls_connection.compat());
// Spawn the Prover to run in the background. // Spawn the Prover to run in the background.
let prover_task = tokio::spawn(prover_fut); let prover_task = tokio::spawn(prover.run(client_socket, verifier_socket));
// MPC-TLS Handshake. // MPC-TLS Handshake.
let (mut request_sender, connection) = let (mut request_sender, connection) =
@@ -140,7 +140,7 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
assert!(response.status() == StatusCode::OK); assert!(response.status() == StatusCode::OK);
// Create proof for the Verifier. // Create proof for the Verifier.
let mut prover = prover_task.await??; let (mut prover, _, mut verifier_socket) = prover_task.await??;
let mut builder = ProveConfig::builder(prover.transcript()); let mut builder = ProveConfig::builder(prover.transcript());
@@ -173,8 +173,8 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let config = builder.build()?; let config = builder.build()?;
prover.prove(&config).await?; prover.prove(&config, &mut verifier_socket).await?;
prover.close().await?; prover.close(&mut verifier_socket).await?;
Ok(()) Ok(())
} }
@@ -183,6 +183,8 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>( async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: T, socket: T,
) -> Result<PartialTranscript> { ) -> Result<PartialTranscript> {
let mut socket = socket.compat();
// Create a root certificate store with the server-fixture's self-signed // Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the // certificate. This is only required for offline testing with the
// server-fixture. // server-fixture.
@@ -194,7 +196,7 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
let verifier = Verifier::new(verifier_config); let verifier = Verifier::new(verifier_config);
// Validate the proposed configuration and then run the TLS commitment protocol. // Validate the proposed configuration and then run the TLS commitment protocol.
let verifier = verifier.commit(socket.compat()).await?; let verifier = verifier.commit(&mut socket).await?;
// This is the opportunity to ensure the prover does not attempt to overload the // This is the opportunity to ensure the prover does not attempt to overload the
// verifier. // verifier.
@@ -212,21 +214,21 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
}; };
if reject.is_some() { if reject.is_some() {
verifier.reject(reject).await?; verifier.reject(&mut socket, reject).await?;
return Err(anyhow::anyhow!("protocol configuration rejected")); return Err(anyhow::anyhow!("protocol configuration rejected"));
} }
// Runs the TLS commitment protocol to completion. // Runs the TLS commitment protocol to completion.
let verifier = verifier.accept().await?.run().await?; let verifier = verifier.accept(&mut socket).await?.run(&mut socket).await?;
// Validate the proving request and then verify. // Validate the proving request and then verify.
let verifier = verifier.verify().await?; let verifier = verifier.verify(&mut socket).await?;
if !verifier.request().server_identity() { if !verifier.request().server_identity() {
let verifier = verifier let verifier = verifier
.reject(Some("expecting to verify the server name")) .reject(&mut socket, Some("expecting to verify the server name"))
.await?; .await?;
verifier.close().await?; verifier.close(&mut socket).await?;
return Err(anyhow::anyhow!("prover did not reveal the server name")); return Err(anyhow::anyhow!("prover did not reveal the server name"));
} }
@@ -237,9 +239,9 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.. ..
}, },
verifier, verifier,
) = verifier.accept().await?; ) = verifier.accept(&mut socket).await?;
verifier.close().await?; verifier.close(&mut socket).await?;
let server_name = server_name.expect("prover should have revealed server name"); let server_name = server_name.expect("prover should have revealed server name");
let transcript = transcript.expect("prover should have revealed transcript data"); let transcript = transcript.expect("prover should have revealed transcript data");

View File

@@ -31,11 +31,10 @@ async fn main() -> Result<()> {
// Connect prover and verifier. // Connect prover and verifier.
let (prover_socket, verifier_socket) = tokio::io::duplex(1 << 23); let (prover_socket, verifier_socket) = tokio::io::duplex(1 << 23);
let (prover_extra_socket, verifier_extra_socket) = tokio::io::duplex(1 << 23);
let (_, transcript) = tokio::try_join!( let (_, transcript) = tokio::try_join!(
prover(prover_socket, prover_extra_socket, &server_addr, &uri), prover(prover_socket, &server_addr, &uri),
verifier(verifier_socket, verifier_extra_socket) verifier(verifier_socket)
)?; )?;
println!("---"); println!("---");

View File

@@ -46,15 +46,15 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tracing::instrument; use tracing::instrument;
#[instrument(skip(verifier_socket, verifier_extra_socket))] #[instrument(skip(verifier_socket))]
pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>( pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
verifier_socket: T, verifier_socket: T,
mut verifier_extra_socket: T,
server_addr: &SocketAddr, server_addr: &SocketAddr,
uri: &str, uri: &str,
) -> Result<()> { ) -> Result<()> {
let uri = uri.parse::<Uri>()?; let mut verifier_socket = verifier_socket.compat();
let uri = uri.parse::<Uri>()?;
if uri.scheme().map(|s| s.as_str()) != Some("https") { if uri.scheme().map(|s| s.as_str()) != Some("https") {
return Err(anyhow::anyhow!("URI must use HTTPS scheme")); return Err(anyhow::anyhow!("URI must use HTTPS scheme"));
} }
@@ -80,32 +80,29 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
.build()?, .build()?,
) )
.build()?, .build()?,
verifier_socket.compat(), &mut verifier_socket,
) )
.await?; .await?;
// Open a TCP connection to the server. // Open a TCP connection to the server.
let client_socket = tokio::net::TcpStream::connect(server_addr).await?; let client_socket = tokio::net::TcpStream::connect(server_addr).await?.compat();
// Bind the prover to the server connection. // Bind the prover to the server connection.
let (tls_connection, prover_fut) = prover let (tls_connection, prover) = prover.setup(
.connect( TlsClientConfig::builder()
TlsClientConfig::builder() .server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?)) // Create a root certificate store with the server-fixture's self-signed
// Create a root certificate store with the server-fixture's self-signed // certificate. This is only required for offline testing with the
// certificate. This is only required for offline testing with the // server-fixture.
// server-fixture. .root_store(RootCertStore {
.root_store(RootCertStore { roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
roots: vec![CertificateDer(CA_CERT_DER.to_vec())], })
}) .build()?,
.build()?, )?;
client_socket.compat(),
)
.await?;
let tls_connection = TokioIo::new(tls_connection.compat()); let tls_connection = TokioIo::new(tls_connection.compat());
// Spawn the Prover to run in the background. // Spawn the Prover to run in the background.
let prover_task = tokio::spawn(prover_fut); let prover_task = tokio::spawn(prover.run(client_socket, verifier_socket));
// MPC-TLS Handshake. // MPC-TLS Handshake.
let (mut request_sender, connection) = let (mut request_sender, connection) =
@@ -133,7 +130,7 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
} }
// Create proof for the Verifier. // Create proof for the Verifier.
let mut prover = prover_task.await??; let (mut prover, _, mut verifier_socket) = prover_task.await??;
let transcript = prover.transcript().clone(); let transcript = prover.transcript().clone();
let mut prove_config_builder = ProveConfig::builder(&transcript); let mut prove_config_builder = ProveConfig::builder(&transcript);
@@ -167,8 +164,8 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let prove_config = prove_config_builder.build()?; let prove_config = prove_config_builder.build()?;
// MPC-TLS prove // MPC-TLS prove
let prover_output = prover.prove(&prove_config).await?; let prover_output = prover.prove(&prove_config, &mut verifier_socket).await?;
prover.close().await?; prover.close(&mut verifier_socket).await?;
// Prove birthdate is more than 18 years ago. // Prove birthdate is more than 18 years ago.
let received_commitments = received_commitments(&prover_output.transcript_commitments); let received_commitments = received_commitments(&prover_output.transcript_commitments);
@@ -184,8 +181,10 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
// Sent zk proof bundle to verifier // Sent zk proof bundle to verifier
let serialized_proof = bincode::serialize(&proof_bundle)?; let serialized_proof = bincode::serialize(&proof_bundle)?;
verifier_extra_socket.write_all(&serialized_proof).await?;
verifier_extra_socket.shutdown().await?; let mut verifier_socket = verifier_socket.into_inner();
verifier_socket.write_all(&serialized_proof).await?;
verifier_socket.shutdown().await?;
Ok(()) Ok(())
} }

View File

@@ -20,11 +20,12 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio_util::compat::TokioAsyncReadCompatExt; use tokio_util::compat::TokioAsyncReadCompatExt;
use tracing::instrument; use tracing::instrument;
#[instrument(skip(socket, extra_socket))] #[instrument(skip(prover_socket))]
pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>( pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: T, prover_socket: T,
mut extra_socket: T,
) -> Result<PartialTranscript> { ) -> Result<PartialTranscript> {
let mut prover_socket = prover_socket.compat();
let verifier = Verifier::new( let verifier = Verifier::new(
VerifierConfig::builder() VerifierConfig::builder()
// Create a root certificate store with the server-fixture's self-signed // Create a root certificate store with the server-fixture's self-signed
@@ -37,7 +38,7 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
); );
// Validate the proposed configuration and then run the TLS commitment protocol. // Validate the proposed configuration and then run the TLS commitment protocol.
let verifier = verifier.commit(socket.compat()).await?; let verifier = verifier.commit(&mut prover_socket).await?;
// This is the opportunity to ensure the prover does not attempt to overload the // This is the opportunity to ensure the prover does not attempt to overload the
// verifier. // verifier.
@@ -55,24 +56,29 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
}; };
if reject.is_some() { if reject.is_some() {
verifier.reject(reject).await?; verifier.reject(&mut prover_socket, reject).await?;
return Err(anyhow::anyhow!("protocol configuration rejected")); return Err(anyhow::anyhow!("protocol configuration rejected"));
} }
// Runs the TLS commitment protocol to completion. // Runs the TLS commitment protocol to completion.
let verifier = verifier.accept().await?.run().await?; let verifier = verifier
.accept(&mut prover_socket)
.await?
.run(&mut prover_socket)
.await?;
// Validate the proving request and then verify. // Validate the proving request and then verify.
let verifier = verifier.verify().await?; let verifier = verifier.verify(&mut prover_socket).await?;
let request = verifier.request(); let request = verifier.request();
if !request.server_identity() || request.reveal().is_none() { if !request.server_identity() || request.reveal().is_none() {
let verifier = verifier let verifier = verifier
.reject(Some( .reject(
"expecting to verify the server name and transcript data", &mut prover_socket,
)) Some("expecting to verify the server name and transcript data"),
)
.await?; .await?;
verifier.close().await?; verifier.close(&mut prover_socket).await?;
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"prover did not reveal the server name and transcript data" "prover did not reveal the server name and transcript data"
)); ));
@@ -86,10 +92,9 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
.. ..
}, },
verifier, verifier,
) = verifier.accept().await?; ) = verifier.accept(&mut prover_socket).await?;
verifier.close().await?;
verifier.close(&mut prover_socket).await?;
let server_name = server_name.expect("server name should be present"); let server_name = server_name.expect("server name should be present");
let transcript = transcript.expect("transcript should be present"); let transcript = transcript.expect("transcript should be present");
@@ -126,7 +131,9 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
// Receive ZKProof information from prover // Receive ZKProof information from prover
let mut buf = Vec::new(); let mut buf = Vec::new();
extra_socket.read_to_end(&mut buf).await?;
let mut prover_socket = prover_socket.into_inner();
prover_socket.read_to_end(&mut buf).await?;
if buf.is_empty() { if buf.is_empty() {
return Err(anyhow::anyhow!("No ZK proof data received from prover")); return Err(anyhow::anyhow!("No ZK proof data received from prover"));

View File

@@ -23,7 +23,8 @@ use crate::{
}; };
pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<ProverMetrics> { pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<ProverMetrics> {
let verifier_io = Meter::new(provider.provide_proto_io().await?); let mut verifier_io = Meter::new(provider.provide_proto_io().await?);
let mut server_io = provider.provide_server_io().await?;
let sent = verifier_io.sent(); let sent = verifier_io.sent();
let recv = verifier_io.recv(); let recv = verifier_io.recv();
@@ -49,7 +50,7 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
.build() .build()
}?) }?)
.build()?, .build()?,
verifier_io, &mut verifier_io,
) )
.await?; .await?;
@@ -58,19 +59,18 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let uploaded_preprocess = sent.load(Ordering::Relaxed); let uploaded_preprocess = sent.load(Ordering::Relaxed);
let downloaded_preprocess = recv.load(Ordering::Relaxed); let downloaded_preprocess = recv.load(Ordering::Relaxed);
let (mut conn, prover_fut) = prover let (mut conn, prover) = prover.setup(
.connect( TlsClientConfig::builder()
TlsClientConfig::builder() .server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?)) .root_store(RootCertStore {
.root_store(RootCertStore { roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
roots: vec![CertificateDer(CA_CERT_DER.to_vec())], })
}) .build()?,
.build()?, )?;
provider.provide_server_io().await?,
)
.await?;
let (_, mut prover) = futures::try_join!( let mut prover = prover.connect(&mut server_io, &mut verifier_io);
futures::try_join!(
async { async {
let request = format!( let request = format!(
"GET /bytes?size={} HTTP/1.1\r\nConnection: close\r\nData: {}\r\n\r\n", "GET /bytes?size={} HTTP/1.1\r\nConnection: close\r\nData: {}\r\n\r\n",
@@ -87,8 +87,9 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
Ok(()) Ok(())
}, },
prover_fut.map_err(anyhow::Error::from) (&mut prover).map_err(anyhow::Error::from)
)?; )?;
let mut prover = prover.finish()?;
let time_online = time_start_online.elapsed().as_millis(); let time_online = time_start_online.elapsed().as_millis();
let uploaded_online = sent.load(Ordering::Relaxed) - uploaded_preprocess; let uploaded_online = sent.load(Ordering::Relaxed) - uploaded_preprocess;
@@ -118,8 +119,8 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let prove_config = builder.build()?; let prove_config = builder.build()?;
prover.prove(&prove_config).await?; prover.prove(&prove_config, &mut verifier_io).await?;
prover.close().await?; prover.close(&mut verifier_io).await?;
let time_total = time_start.elapsed().as_millis(); let time_total = time_start.elapsed().as_millis();

View File

@@ -11,6 +11,8 @@ use tlsn_server_fixture_certs::CA_CERT_DER;
use crate::IoProvider; use crate::IoProvider;
pub async fn bench_verifier(provider: &IoProvider, _config: &Bench) -> Result<()> { pub async fn bench_verifier(provider: &IoProvider, _config: &Bench) -> Result<()> {
let mut prover_io = provider.provide_proto_io().await?;
let verifier = Verifier::new( let verifier = Verifier::new(
VerifierConfig::builder() VerifierConfig::builder()
.root_store(RootCertStore { .root_store(RootCertStore {
@@ -22,12 +24,16 @@ pub async fn bench_verifier(provider: &IoProvider, _config: &Bench) -> Result<()
let verifier = verifier let verifier = verifier
.commit(provider.provide_proto_io().await?) .commit(provider.provide_proto_io().await?)
.await? .await?
.accept() .accept(&mut prover_io)
.await? .await?
.run() .run(&mut prover_io)
.await?; .await?;
let (_, verifier) = verifier.verify().await?.accept().await?; let (_, verifier) = verifier
verifier.close().await?; .verify(&mut prover_io)
.await?
.accept(&mut prover_io)
.await?;
verifier.close(&mut prover_io).await?;
Ok(()) Ok(())
} }

View File

@@ -28,6 +28,8 @@ const MAX_RECV_DATA: usize = 1 << 11;
crate::test!("basic", prover, verifier); crate::test!("basic", prover, verifier);
async fn prover(provider: &IoProvider) { async fn prover(provider: &IoProvider) {
let mut verifier_io = provider.provide_proto_io().await.unwrap();
let prover = Prover::new(ProverConfig::builder().build().unwrap()) let prover = Prover::new(ProverConfig::builder().build().unwrap())
.commit( .commit(
TlsCommitConfig::builder() TlsCommitConfig::builder()
@@ -41,13 +43,15 @@ async fn prover(provider: &IoProvider) {
) )
.build() .build()
.unwrap(), .unwrap(),
provider.provide_proto_io().await.unwrap(), &mut verifier_io,
) )
.await .await
.unwrap(); .unwrap();
let (tls_connection, prover_fut) = prover let server_io = provider.provide_server_io().await.unwrap();
.connect(
let (tls_connection, prover) = prover
.setup(
TlsClientConfig::builder() TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap())) .server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.root_store(RootCertStore { .root_store(RootCertStore {
@@ -55,12 +59,10 @@ async fn prover(provider: &IoProvider) {
}) })
.build() .build()
.unwrap(), .unwrap(),
provider.provide_server_io().await.unwrap(),
) )
.await
.unwrap(); .unwrap();
let prover_task = spawn(prover_fut); let prover_task = spawn(prover.run(server_io, verifier_io));
let (mut request_sender, connection) = let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(FuturesIo::new(tls_connection)) hyper::client::conn::http1::handshake(FuturesIo::new(tls_connection))
@@ -87,7 +89,7 @@ async fn prover(provider: &IoProvider) {
let _ = response.into_body().collect().await.unwrap().to_bytes(); let _ = response.into_body().collect().await.unwrap().to_bytes();
let mut prover = prover_task.await.unwrap().unwrap(); let (mut prover, _, mut verifier_io) = prover_task.await.unwrap().unwrap();
let (sent_len, recv_len) = prover.transcript().len(); let (sent_len, recv_len) = prover.transcript().len();
@@ -114,11 +116,13 @@ async fn prover(provider: &IoProvider) {
let config = builder.build().unwrap(); let config = builder.build().unwrap();
prover.prove(&config).await.unwrap(); prover.prove(&config, &mut verifier_io).await.unwrap();
prover.close().await.unwrap(); prover.close(&mut verifier_io).await.unwrap();
} }
async fn verifier(provider: &IoProvider) { async fn verifier(provider: &IoProvider) {
let mut prover_io = provider.provide_proto_io().await.unwrap();
let config = VerifierConfig::builder() let config = VerifierConfig::builder()
.root_store(RootCertStore { .root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())], roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
@@ -127,13 +131,13 @@ async fn verifier(provider: &IoProvider) {
.unwrap(); .unwrap();
let verifier = Verifier::new(config) let verifier = Verifier::new(config)
.commit(provider.provide_proto_io().await.unwrap()) .commit(&mut prover_io)
.await .await
.unwrap() .unwrap()
.accept() .accept(&mut prover_io)
.await .await
.unwrap() .unwrap()
.run() .run(&mut prover_io)
.await .await
.unwrap(); .unwrap();
@@ -144,9 +148,15 @@ async fn verifier(provider: &IoProvider) {
.. ..
}, },
verifier, verifier,
) = verifier.verify().await.unwrap().accept().await.unwrap(); ) = verifier
.verify(&mut prover_io)
.await
.unwrap()
.accept(&mut prover_io)
.await
.unwrap();
verifier.close().await.unwrap(); verifier.close(&mut prover_io).await.unwrap();
let ServerName::Dns(server_name) = server_name.unwrap(); let ServerName::Dns(server_name) = server_name.unwrap();

View File

@@ -66,7 +66,6 @@ rand_chacha = { workspace = true }
rstest = { workspace = true } rstest = { workspace = true }
tls-server-fixture = { workspace = true } tls-server-fixture = { workspace = true }
tlsn-tls-client = { workspace = true } tlsn-tls-client = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
tokio-util = { workspace = true, features = ["compat"] } tokio-util = { workspace = true, features = ["compat"] }
tracing-subscriber = { workspace = true } tracing-subscriber = { workspace = true }

View File

@@ -378,15 +378,29 @@ impl MpcTlsLeader {
Ok(()) Ok(())
} }
/// Defers decryption of any incoming messages. /// Enables or disables the decryption of any incoming messages.
///
/// # Arguments
///
/// * `enable` - Whether to enable or disable decryption.
#[instrument(level = "debug", skip_all, err)] #[instrument(level = "debug", skip_all, err)]
pub async fn defer_decryption(&mut self) -> Result<(), MpcTlsError> { pub fn enable_decryption(&mut self, enable: bool) -> Result<(), MpcTlsError> {
self.is_decrypting = false; self.is_decrypting = enable;
self.notifier.clear();
if enable {
self.notifier.set();
} else {
self.notifier.clear();
}
Ok(()) Ok(())
} }
/// Returns if incoming messages are decrypted.
pub fn is_decrypting(&self) -> bool {
self.is_decrypting
}
/// Stops the actor. /// Stops the actor.
pub fn stop(&mut self, ctx: &mut LudiContext<Self>) { pub fn stop(&mut self, ctx: &mut LudiContext<Self>) {
ctx.stop(); ctx.stop();

View File

@@ -32,10 +32,14 @@ impl MpcTlsLeaderCtrl {
Self { address } Self { address }
} }
/// Defers decryption of any incoming messages. /// Enables or disables the decryption of any incoming messages.
pub async fn defer_decryption(&self) -> Result<(), MpcTlsError> { ///
/// # Arguments
///
/// * `enable` - Whether to enable or disable decryption.
pub async fn enable_decryption(&self, enable: bool) -> Result<(), MpcTlsError> {
self.address self.address
.send(DeferDecryption) .send(EnableDecryption { enable })
.await .await
.map_err(MpcTlsError::actor)? .map_err(MpcTlsError::actor)?
} }
@@ -981,7 +985,7 @@ impl Handler<BackendMsgServerClosed> for MpcTlsLeader {
} }
} }
impl Dispatch<MpcTlsLeader> for DeferDecryption { impl Dispatch<MpcTlsLeader> for EnableDecryption {
fn dispatch<R: FnOnce(Self::Return) + Send>( fn dispatch<R: FnOnce(Self::Return) + Send>(
self, self,
actor: &mut MpcTlsLeader, actor: &mut MpcTlsLeader,
@@ -992,13 +996,13 @@ impl Dispatch<MpcTlsLeader> for DeferDecryption {
} }
} }
impl Handler<DeferDecryption> for MpcTlsLeader { impl Handler<EnableDecryption> for MpcTlsLeader {
async fn handle( async fn handle(
&mut self, &mut self,
_msg: DeferDecryption, msg: EnableDecryption,
_ctx: &mut LudiCtx<Self>, _ctx: &mut LudiCtx<Self>,
) -> <DeferDecryption as Message>::Return { ) -> <EnableDecryption as Message>::Return {
self.defer_decryption().await self.enable_decryption(msg.enable)
} }
} }
@@ -1048,7 +1052,7 @@ pub enum MpcTlsLeaderMsg {
BackendMsgGetNotify(BackendMsgGetNotify), BackendMsgGetNotify(BackendMsgGetNotify),
BackendMsgIsEmpty(BackendMsgIsEmpty), BackendMsgIsEmpty(BackendMsgIsEmpty),
BackendMsgServerClosed(BackendMsgServerClosed), BackendMsgServerClosed(BackendMsgServerClosed),
DeferDecryption(DeferDecryption), DeferDecryption(EnableDecryption),
Stop(Stop), Stop(Stop),
} }
@@ -1083,7 +1087,7 @@ pub enum MpcTlsLeaderMsgReturn {
BackendMsgGetNotify(<BackendMsgGetNotify as Message>::Return), BackendMsgGetNotify(<BackendMsgGetNotify as Message>::Return),
BackendMsgIsEmpty(<BackendMsgIsEmpty as Message>::Return), BackendMsgIsEmpty(<BackendMsgIsEmpty as Message>::Return),
BackendMsgServerClosed(<BackendMsgServerClosed as Message>::Return), BackendMsgServerClosed(<BackendMsgServerClosed as Message>::Return),
DeferDecryption(<DeferDecryption as Message>::Return), DeferDecryption(<EnableDecryption as Message>::Return),
Stop(<Stop as Message>::Return), Stop(<Stop as Message>::Return),
} }
@@ -1732,23 +1736,25 @@ impl Wrap<BackendMsgServerClosed> for MpcTlsLeaderMsg {
} }
} }
/// Message to start deferring the decryption. /// Message to enable or disable the decryption of messages.
#[allow(missing_docs)] #[allow(missing_docs)]
#[derive(Debug)] #[derive(Debug)]
pub struct DeferDecryption; pub struct EnableDecryption {
pub enable: bool,
}
impl Message for DeferDecryption { impl Message for EnableDecryption {
type Return = Result<(), MpcTlsError>; type Return = Result<(), MpcTlsError>;
} }
impl From<DeferDecryption> for MpcTlsLeaderMsg { impl From<EnableDecryption> for MpcTlsLeaderMsg {
fn from(value: DeferDecryption) -> Self { fn from(value: EnableDecryption) -> Self {
MpcTlsLeaderMsg::DeferDecryption(value) MpcTlsLeaderMsg::DeferDecryption(value)
} }
} }
impl Wrap<DeferDecryption> for MpcTlsLeaderMsg { impl Wrap<EnableDecryption> for MpcTlsLeaderMsg {
fn unwrap_return(ret: Self::Return) -> Result<<DeferDecryption as Message>::Return, Error> { fn unwrap_return(ret: Self::Return) -> Result<<EnableDecryption as Message>::Return, Error> {
match ret { match ret {
Self::Return::DeferDecryption(value) => Ok(value), Self::Return::DeferDecryption(value) => Ok(value),
_ => Err(Error::Wrapper), _ => Err(Error::Wrapper),

View File

@@ -1,160 +0,0 @@
use std::sync::Arc;
use futures::{AsyncReadExt, AsyncWriteExt};
use mpc_tls::{Config, MpcTlsFollower, MpcTlsLeader};
use mpz_common::context::test_mt_context;
use mpz_core::Block;
use mpz_ideal_vm::IdealVm;
use mpz_memory_core::correlated::Delta;
use mpz_ot::{
ideal::rcot::ideal_rcot,
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
};
use rand::{rngs::StdRng, SeedableRng};
use rustls_pki_types::CertificateDer;
use tls_client::RootCertStore;
use tls_client_async::bind_client;
use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN};
use tokio::sync::Mutex;
use tokio_util::compat::TokioAsyncReadCompatExt;
use webpki::anchor_from_trusted_cert;
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn mpc_tls_test() {
tracing_subscriber::fmt::init();
let config = Config::builder()
.defer_decryption(false)
.max_sent(1 << 13)
.max_recv_online(1 << 13)
.max_recv(1 << 13)
.build()
.unwrap();
let (leader, follower) = build_pair(config);
tokio::try_join!(
tokio::spawn(leader_task(leader)),
tokio::spawn(follower_task(follower))
)
.unwrap();
}
async fn leader_task(mut leader: MpcTlsLeader) {
leader.alloc().unwrap();
leader.preprocess().await.unwrap();
let (leader_ctrl, leader_fut) = leader.run();
tokio::spawn(async { leader_fut.await.unwrap() });
let config = tls_client::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(RootCertStore {
roots: vec![anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned()],
})
.with_no_client_auth();
let server_name = SERVER_DOMAIN.try_into().unwrap();
let client = tls_client::ClientConnection::new(
Arc::new(config),
Box::new(leader_ctrl.clone()),
server_name,
)
.unwrap();
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
tokio::spawn(bind_test_server_hyper(server_socket.compat()));
let (mut conn, conn_fut) = bind_client(client_socket.compat(), client);
let handle = tokio::spawn(async { conn_fut.await.unwrap() });
let msg = concat!(
"POST /echo HTTP/1.1\r\n",
"Host: test-server.io\r\n",
"Connection: keep-alive\r\n",
"Accept-Encoding: identity\r\n",
"Content-Length: 5\r\n",
"\r\n",
"hello",
"\r\n"
);
conn.write_all(msg.as_bytes()).await.unwrap();
let mut buf = vec![0u8; 48];
conn.read_exact(&mut buf).await.unwrap();
leader_ctrl.defer_decryption().await.unwrap();
let msg = concat!(
"POST /echo HTTP/1.1\r\n",
"Host: test-server.io\r\n",
"Connection: close\r\n",
"Accept-Encoding: identity\r\n",
"Content-Length: 5\r\n",
"\r\n",
"hello",
"\r\n"
);
conn.write_all(msg.as_bytes()).await.unwrap();
conn.close().await.unwrap();
let mut buf = vec![0u8; 1024];
conn.read_to_end(&mut buf).await.unwrap();
leader_ctrl.stop().await.unwrap();
handle.await.unwrap();
}
async fn follower_task(mut follower: MpcTlsFollower) {
follower.alloc().unwrap();
follower.preprocess().await.unwrap();
follower.run().await.unwrap();
}
fn build_pair(config: Config) -> (MpcTlsLeader, MpcTlsFollower) {
let mut rng = StdRng::seed_from_u64(0);
let (mut mt_a, mut mt_b) = test_mt_context(8);
let ctx_a = futures::executor::block_on(mt_a.new_context()).unwrap();
let ctx_b = futures::executor::block_on(mt_b.new_context()).unwrap();
let delta_a = Delta::new(Block::random(&mut rng));
let delta_b = Delta::new(Block::random(&mut rng));
let (rcot_send_a, rcot_recv_b) = ideal_rcot(Block::random(&mut rng), delta_a.into_inner());
let (rcot_send_b, rcot_recv_a) = ideal_rcot(Block::random(&mut rng), delta_b.into_inner());
let rcot_send_a = SharedRCOTSender::new(rcot_send_a);
let rcot_send_b = SharedRCOTSender::new(rcot_send_b);
let rcot_recv_a = SharedRCOTReceiver::new(rcot_recv_a);
let rcot_recv_b = SharedRCOTReceiver::new(rcot_recv_b);
let mpc_a = Arc::new(Mutex::new(IdealVm::new()));
let mpc_b = Arc::new(Mutex::new(IdealVm::new()));
let leader = MpcTlsLeader::new(
config.clone(),
ctx_a,
mpc_a,
(rcot_send_a.clone(), rcot_send_a.clone(), rcot_send_a),
rcot_recv_a,
);
let follower = MpcTlsFollower::new(
config,
ctx_b,
mpc_b,
rcot_send_b,
(rcot_recv_b.clone(), rcot_recv_b.clone(), rcot_recv_b),
);
(leader, follower)
}

View File

@@ -1,39 +0,0 @@
[package]
name = "tlsn-tls-client-async"
authors = ["TLSNotary Team"]
description = "An async TLS client for TLSNotary"
keywords = ["tls", "mpc", "2pc", "client", "async"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.14-pre"
edition = "2021"
[lints]
workspace = true
[lib]
name = "tls_client_async"
[features]
default = ["tracing"]
tracing = ["dep:tracing"]
[dependencies]
tlsn-tls-client = { workspace = true }
bytes = { workspace = true }
futures = { workspace = true }
thiserror = { workspace = true }
tokio-util = { workspace = true, features = ["io", "compat"] }
tracing = { workspace = true, optional = true }
[dev-dependencies]
tls-server-fixture = { workspace = true }
http-body-util = { workspace = true }
hyper = { workspace = true, features = ["client", "http1"] }
hyper-util = { workspace = true, features = ["full"] }
rstest = { workspace = true }
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] }
rustls-webpki = { workspace = true }
rustls-pki-types = { workspace = true }

View File

@@ -1,89 +0,0 @@
use bytes::Bytes;
use futures::{
channel::mpsc::{Receiver, SendError, Sender},
sink::SinkMapErr,
AsyncRead, AsyncWrite, SinkExt,
};
use std::{
io::{Error as IoError, ErrorKind as IoErrorKind},
pin::Pin,
task::{Context, Poll},
};
use tokio_util::{
compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt},
io::{CopyToBytes, SinkWriter, StreamReader},
};
type CompatSinkWriter =
Compat<SinkWriter<CopyToBytes<SinkMapErr<Sender<Bytes>, fn(SendError) -> IoError>>>>;
/// A TLS connection to a server.
///
/// This type implements `AsyncRead` and `AsyncWrite` and can be used to
/// communicate with a server using TLS.
///
/// # Note
///
/// This connection is closed on a best-effort basis if this is dropped. To
/// ensure a clean close, you should call
/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the
/// connection.
#[derive(Debug)]
pub struct TlsConnection {
/// The data to be transmitted to the server is sent to this sink.
tx_sender: CompatSinkWriter,
/// The data to be received from the server is received from this stream.
rx_receiver: Compat<StreamReader<Receiver<Result<Bytes, IoError>>, Bytes>>,
}
impl TlsConnection {
/// Creates a new TLS connection.
pub(crate) fn new(
tx_sender: Sender<Bytes>,
rx_receiver: Receiver<Result<Bytes, IoError>>,
) -> Self {
fn convert_error(err: SendError) -> IoError {
if err.is_disconnected() {
IoErrorKind::BrokenPipe.into()
} else {
IoErrorKind::WouldBlock.into()
}
}
Self {
tx_sender: SinkWriter::new(CopyToBytes::new(
tx_sender.sink_map_err(convert_error as fn(SendError) -> IoError),
))
.compat_write(),
rx_receiver: StreamReader::new(rx_receiver).compat(),
}
}
}
impl AsyncRead for TlsConnection {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, IoError>> {
Pin::new(&mut self.rx_receiver).poll_read(cx, buf)
}
}
impl AsyncWrite for TlsConnection {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, IoError>> {
Pin::new(&mut self.tx_sender).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
Pin::new(&mut self.tx_sender).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
Pin::new(&mut self.tx_sender).poll_close(cx)
}
}

View File

@@ -1,269 +0,0 @@
//! Provides a TLS client which exposes an async socket.
//!
//! This library provides the [bind_client] function which attaches a TLS client
//! to a socket connection and then exposes a [TlsConnection] object, which
//! provides an async socket API for reading and writing cleartext. The TLS
//! client will then automatically encrypt and decrypt traffic and forward that
//! to the provided socket.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
#![forbid(unsafe_code)]
mod conn;
use bytes::{Buf, Bytes};
use futures::{
channel::mpsc, future::Fuse, select_biased, stream::Next, AsyncRead, AsyncReadExt, AsyncWrite,
AsyncWriteExt, Future, FutureExt, SinkExt, StreamExt,
};
use std::{
pin::Pin,
task::{Context, Poll},
};
#[cfg(feature = "tracing")]
use tracing::{debug, debug_span, trace, warn, Instrument};
use tls_client::ClientConnection;
pub use conn::TlsConnection;
const RX_TLS_BUF_SIZE: usize = 1 << 13; // 8 KiB
const RX_BUF_SIZE: usize = 1 << 13; // 8 KiB
/// An error that can occur during a TLS connection.
#[allow(missing_docs)]
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
#[error(transparent)]
TlsError(#[from] tls_client::Error),
#[error(transparent)]
IOError(#[from] std::io::Error),
}
/// Closed connection data.
#[derive(Debug)]
pub struct ClosedConnection {
/// The connection for the client
pub client: ClientConnection,
/// Sent plaintext bytes
pub sent: Vec<u8>,
/// Received plaintext bytes
pub recv: Vec<u8>,
}
/// A future which runs the TLS connection to completion.
///
/// This future must be polled in order for the connection to make progress.
#[must_use = "futures do nothing unless polled"]
pub struct ConnectionFuture {
fut: Pin<Box<dyn Future<Output = Result<ClosedConnection, ConnectionError>> + Send>>,
}
impl Future for ConnectionFuture {
type Output = Result<ClosedConnection, ConnectionError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.fut.poll_unpin(cx)
}
}
/// Binds a client connection to the provided socket.
///
/// Returns a connection handle and a future which runs the connection to
/// completion.
///
/// # Errors
///
/// Any connection errors that occur will be returned from the future, not
/// [`TlsConnection`].
pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
socket: T,
mut client: ClientConnection,
) -> (TlsConnection, ConnectionFuture) {
let (tx_sender, mut tx_receiver) = mpsc::channel(1 << 14);
let (mut rx_sender, rx_receiver) = mpsc::channel(1 << 14);
let conn = TlsConnection::new(tx_sender, rx_receiver);
let fut = async move {
client.start().await?;
let mut notify = client.get_notify().await?;
let (mut server_rx, mut server_tx) = socket.split();
let mut rx_tls_buf = [0u8; RX_TLS_BUF_SIZE];
let mut rx_buf = [0u8; RX_BUF_SIZE];
let mut handshake_done = false;
let mut client_closed = false;
let mut server_closed = false;
let mut sent = Vec::with_capacity(1024);
let mut recv = Vec::with_capacity(1024);
let mut rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse();
// We don't start writing application data until the handshake is complete.
let mut tx_recv_fut: Fuse<Next<'_, mpsc::Receiver<Bytes>>> = Fuse::terminated();
// Runs both the tx and rx halves of the connection to completion.
// This loop does not terminate until the *SERVER* closes the connection and
// we've processed all received data. If an error occurs, the `TlsConnection`
// channels will be closed and the error will be returned from this future.
'conn: loop {
// Write all pending TLS data to the server.
if client.wants_write() && !client_closed {
#[cfg(feature = "tracing")]
trace!("client wants to write");
while client.wants_write() {
let _sent = client.write_tls_async(&mut server_tx).await?;
#[cfg(feature = "tracing")]
trace!("sent {} tls bytes to server", _sent);
}
server_tx.flush().await?;
}
// Forward received plaintext to `TlsConnection`.
while !client.plaintext_is_empty() {
let read = client.read_plaintext(&mut rx_buf)?;
recv.extend(&rx_buf[..read]);
// Ignore if the receiver has hung up.
_ = rx_sender
.send(Ok(Bytes::copy_from_slice(&rx_buf[..read])))
.await;
#[cfg(feature = "tracing")]
trace!("forwarded {} plaintext bytes to conn", read);
}
if !client.is_handshaking() && !handshake_done {
#[cfg(feature = "tracing")]
debug!("handshake complete");
handshake_done = true;
// Start reading application data that needs to be transmitted from the
// `TlsConnection`.
tx_recv_fut = tx_receiver.next().fuse();
}
if server_closed && client.plaintext_is_empty() && client.is_empty().await? {
break 'conn;
}
select_biased! {
// Reads TLS data from the server and writes it into the client.
received = &mut rx_tls_fut => {
let received = received?;
#[cfg(feature = "tracing")]
trace!("received {} tls bytes from server", received);
// Loop until we've processed all the data we received in this read.
// Note that we must make one iteration even if `received == 0`.
let mut processed = 0;
let mut reader = rx_tls_buf[..received].reader();
loop {
processed += client.read_tls(&mut reader)?;
client.process_new_packets().await?;
debug_assert!(processed <= received);
if processed >= received {
break;
}
}
#[cfg(feature = "tracing")]
trace!("processed {} tls bytes from server", processed);
// By convention if `AsyncRead::read` returns 0, it means EOF, i.e. the peer
// has closed the socket.
if received == 0 {
#[cfg(feature = "tracing")]
debug!("server closed connection");
server_closed = true;
client.server_closed().await?;
// Do not read from the socket again.
rx_tls_fut = Fuse::terminated();
} else {
// Reset the read future so next iteration we can read again.
rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse();
}
}
// If we receive None from `TlsConnection`, it has closed, so we
// send a close_notify to the server.
data = &mut tx_recv_fut => {
if let Some(data) = data {
#[cfg(feature = "tracing")]
trace!("writing {} plaintext bytes to client", data.len());
sent.extend(&data);
client
.write_all_plaintext(&data)
.await?;
tx_recv_fut = tx_receiver.next().fuse();
} else {
if !server_closed {
if let Err(e) = send_close_notify(&mut client, &mut server_tx).await {
#[cfg(feature = "tracing")]
warn!("failed to send close_notify to server: {}", e);
}
}
client_closed = true;
tx_recv_fut = Fuse::terminated();
}
}
// Waits for a notification from the backend that it is ready to decrypt data.
_ = &mut notify => {
#[cfg(feature = "tracing")]
trace!("backend is ready to decrypt");
client.process_new_packets().await?;
}
}
}
#[cfg(feature = "tracing")]
debug!("client shutdown");
_ = server_tx.close().await;
tx_receiver.close();
rx_sender.close_channel();
#[cfg(feature = "tracing")]
trace!(
"server close notify: {}, sent: {}, recv: {}",
client.received_close_notify(),
sent.len(),
recv.len()
);
Ok(ClosedConnection { client, sent, recv })
};
#[cfg(feature = "tracing")]
let fut = fut.instrument(debug_span!("tls_connection"));
let fut = ConnectionFuture { fut: Box::pin(fut) };
(conn, fut)
}
async fn send_close_notify(
client: &mut ClientConnection,
server_tx: &mut (impl AsyncWrite + Unpin),
) -> Result<(), ConnectionError> {
#[cfg(feature = "tracing")]
trace!("sending close_notify to server");
client.send_close_notify().await?;
client.process_new_packets().await?;
// Flush all remaining plaintext
while client.wants_write() {
client.write_tls_async(server_tx).await?;
}
server_tx.flush().await?;
Ok(())
}

View File

@@ -1,438 +0,0 @@
use std::{str, sync::Arc};
use core::future::Future;
use futures::{AsyncReadExt, AsyncWriteExt};
use http_body_util::{BodyExt as _, Full};
use hyper::{body::Bytes, Request, StatusCode};
use hyper_util::rt::TokioIo;
use rstest::{fixture, rstest};
use rustls_pki_types::CertificateDer;
use tls_client::{ClientConfig, ClientConnection, RustCryptoBackend, ServerName};
use tls_client_async::{bind_client, ClosedConnection, ConnectionError, TlsConnection};
use tls_server_fixture::{
bind_test_server, bind_test_server_hyper, APP_RECORD_LENGTH, CA_CERT_DER, CLOSE_DELAY,
SERVER_DOMAIN,
};
use tokio::task::JoinHandle;
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use webpki::anchor_from_trusted_cert;
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
// An established client TLS connection
struct TlsFixture {
client_tls_conn: TlsConnection,
// a handle that must be `.await`ed to get the result of a TLS connection
closed_tls_task: JoinHandle<Result<ClosedConnection, ConnectionError>>,
}
// Sets up a TLS connection between client and server and sends a hello message
#[fixture]
async fn set_up_tls() -> TlsFixture {
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
let _server_task = tokio::spawn(bind_test_server(server_socket.compat()));
let mut root_store = tls_client::RootCertStore::empty();
root_store
.roots
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let client = ClientConnection::new(
Arc::new(config),
Box::new(RustCryptoBackend::new()),
ServerName::try_from(SERVER_DOMAIN).unwrap(),
)
.unwrap();
let (mut client_tls_conn, tls_fut) = bind_client(client_socket.compat(), client);
let closed_tls_task = tokio::spawn(tls_fut);
client_tls_conn
.write_all(&pad("expecting you to send back hello".to_string()))
.await
.unwrap();
// give the server some time to respond
std::thread::sleep(std::time::Duration::from_millis(10));
let mut plaintext = vec![0u8; 320];
let n = client_tls_conn.read(&mut plaintext).await.unwrap();
let s = str::from_utf8(&plaintext[0..n]).unwrap();
assert_eq!(s, "hello");
TlsFixture {
client_tls_conn,
closed_tls_task,
}
}
// Expect the async tls client wrapped in `hyper::client` to make a successful
// request and receive the expected response
#[tokio::test]
async fn test_hyper_ok() {
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat()));
let mut root_store = tls_client::RootCertStore::empty();
root_store
.roots
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let client = ClientConnection::new(
Arc::new(config),
Box::new(RustCryptoBackend::new()),
ServerName::try_from(SERVER_DOMAIN).unwrap(),
)
.unwrap();
let (conn, tls_fut) = bind_client(client_socket.compat(), client);
let closed_tls_task = tokio::spawn(tls_fut);
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(TokioIo::new(conn.compat()))
.await
.unwrap();
tokio::spawn(connection);
let request = Request::builder()
.uri(format!("https://{SERVER_DOMAIN}/echo"))
.header("Host", SERVER_DOMAIN)
.header("Connection", "close")
.method("POST")
.body(Full::<Bytes>::new("hello".into()))
.unwrap();
let response = request_sender.send_request(request).await.unwrap();
assert!(response.status() == StatusCode::OK);
// Process the response body
response.into_body().collect().await.unwrap().to_bytes();
let _ = server_task.await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server responds to the client's
// close_notify but doesn't close the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_no_socket_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// closing `client_tls_conn` will cause close_notify to be sent by the client;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server responds to the client's
// close_notify AND also closes the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_socket_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us AND close the socket
// after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify_and_close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// closing `client_tls_conn` will cause close_notify to be sent by the client;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server sends close_notify first
// but doesn't close the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_close_notify(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time for server's close_notify to arrive
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect a clean TLS connection closure when server sends close_notify first
// AND also closes the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_close_notify_and_socket_close(
set_up_tls: impl Future<Output = TlsFixture>,
) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send close_notify back to us after 10 ms
client_tls_conn
.write_all(&pad("send_close_notify_and_close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time for server's close_notify to arrive
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(closed_conn.client.received_close_notify());
}
// Expect to be able to read the data after server closes the socket abruptly
#[rstest]
#[tokio::test]
async fn test_ok_read_after_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
..
} = set_up_tls.await;
// instruct the server to send us a hello message
client_tls_conn
.write_all(&pad("send a hello message".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// instruct the server to close the socket
client_tls_conn
.write_all(&pad("close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time to close the socket
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
// try to read some more data
let mut buf = vec![0u8; 10];
let n = client_tls_conn.read(&mut buf).await.unwrap();
assert_eq!(std::str::from_utf8(&buf[0..n]).unwrap(), "hello");
}
// Expect there to be no error when server DOES NOT send close_notify but just
// closes the socket
#[rstest]
#[tokio::test]
async fn test_ok_server_no_close_notify(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to close the socket
client_tls_conn
.write_all(&pad("close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time to close the socket
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
let closed_conn = closed_tls_task.await.unwrap().unwrap();
assert!(!closed_conn.client.received_close_notify());
}
// Expect to register a delay when the server delays closing the socket
#[rstest]
#[tokio::test]
async fn test_ok_delay_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
client_tls_conn
.write_all(&pad("must_delay_when_closing".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// closing `client_tls_conn` will cause close_notify to be sent by the client
client_tls_conn.close().await.unwrap();
use std::time::Instant;
let now = Instant::now();
// this will resolve when the server stops delaying closing the socket
let closed_conn = closed_tls_task.await.unwrap().unwrap();
let elapsed = now.elapsed();
// the elapsed time must be roughly equal to the server's delay
// (give or take timing variations)
assert!(elapsed.as_millis() as u64 > CLOSE_DELAY - 50);
assert!(!closed_conn.client.received_close_notify());
}
// Expect client to error when server sends a corrupted message
#[rstest]
#[tokio::test]
async fn test_err_corrupted(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send a corrupted message
client_tls_conn
.write_all(&pad("send_corrupted_message".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
assert_eq!(
closed_tls_task.await.unwrap().err().unwrap().to_string(),
"received corrupt message"
);
}
// Expect client to error when server sends a TLS record with a bad MAC
#[rstest]
#[tokio::test]
async fn test_err_bad_mac(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send us a TLS record with a bad MAC
client_tls_conn
.write_all(&pad("send_record_with_bad_mac".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
assert_eq!(
closed_tls_task.await.unwrap().err().unwrap().to_string(),
"backend error: Decryption error: \"aead::Error\""
);
}
// Expect client to error when server sends a fatal alert
#[rstest]
#[tokio::test]
async fn test_err_alert(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
closed_tls_task,
} = set_up_tls.await;
// instruct the server to send us a TLS record with a bad MAC
client_tls_conn
.write_all(&pad("send_alert".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
client_tls_conn.close().await.unwrap();
assert_eq!(
closed_tls_task.await.unwrap().err().unwrap().to_string(),
"received fatal alert: BadRecordMac"
);
}
// Expect an error when trying to write data to a connection which server closed
// abruptly
#[rstest]
#[tokio::test]
async fn test_err_write_after_close(set_up_tls: impl Future<Output = TlsFixture>) {
let TlsFixture {
mut client_tls_conn,
..
} = set_up_tls.await;
// instruct the server to close the socket
client_tls_conn
.write_all(&pad("close_socket".to_string()))
.await
.unwrap();
client_tls_conn.flush().await.unwrap();
// give enough time to close the socket
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
// try to send some more data
let res = client_tls_conn
.write_all(&pad("more data".to_string()))
.await;
assert_eq!(res.err().unwrap().kind(), std::io::ErrorKind::BrokenPipe);
}
// Converts a string into a slice zero-padded to APP_RECORD_LENGTH
fn pad(s: String) -> Vec<u8> {
assert!(s.len() <= APP_RECORD_LENGTH);
let mut buf = vec![0u8; APP_RECORD_LENGTH];
buf[..s.len()].copy_from_slice(s.as_bytes());
buf
}

View File

@@ -457,6 +457,9 @@ impl ConnectionCommon {
return Err(Error::CorruptMessage); return Err(Error::CorruptMessage);
} }
// Process outgoing plaintext buffer and encrypt messages.
self.flush_plaintext().await?;
// Process new messages. // Process new messages.
while let Some(msg) = self.message_deframer.frames.pop_front() { while let Some(msg) = self.message_deframer.frames.pop_front() {
// If we're not decrypting yet, we process it immediately. Otherwise it will be // If we're not decrypting yet, we process it immediately. Otherwise it will be
@@ -508,25 +511,22 @@ impl ConnectionCommon {
Ok(state) Ok(state)
} }
/// Write buffer into connection. /// Writes plaintext `buf` into an internal buffer. May not fully process the
pub async fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> { /// whole buffer and returns the processed length.
if let Ok(st) = &mut self.state { pub fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
st.perhaps_write_key_update(&mut self.common_state).await; if buf.is_empty() {
// Don't send empty fragments.
return Ok(0);
} }
self.common_state.send_some_plaintext(buf).await
let len = self.sendable_plaintext.append_limited_copy(buf);
Ok(len)
} }
/// Write entire buffer into connection. /// Writes the entire plaintext `buf` into an internal buffer.
pub async fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> { pub fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<(), Error> {
let mut pos = 0; self.sendable_plaintext.append(buf.to_vec());
while pos < buf.len() { Ok(())
pos += self.write_plaintext(&buf[pos..]).await?;
}
self.backend.flush().await?;
while let Some(msg) = self.backend.next_outgoing().await? {
self.queue_tls_message(msg);
}
Ok(pos)
} }
/// Read TLS content from `rd`. This method does internal /// Read TLS content from `rd`. This method does internal
@@ -690,6 +690,11 @@ impl CommonState {
self.received_plaintext.is_empty() self.received_plaintext.is_empty()
} }
/// Returns true if the buffer for sendable plaintext is full.
pub fn sendable_plaintext_is_full(&self) -> bool {
self.sendable_plaintext.is_full()
}
/// Returns true if the connection is currently performing the TLS /// Returns true if the connection is currently performing the TLS
/// handshake. /// handshake.
/// ///
@@ -782,15 +787,6 @@ impl CommonState {
} }
} }
/// Send plaintext application data, fragmenting and
/// encrypting it as it goes out.
///
/// If internal buffers are too small, this function will not accept
/// all the data.
pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> Result<usize, Error> {
self.send_plain(data, Limit::Yes).await
}
// Changing the keys must not span any fragmented handshake // Changing the keys must not span any fragmented handshake
// messages. Otherwise the defragmented messages will have // messages. Otherwise the defragmented messages will have
// been protected with two different record layer protections, // been protected with two different record layer protections,
@@ -931,32 +927,6 @@ impl CommonState {
self.sendable_tls.write_to_async(wr).await self.sendable_tls.write_to_async(wr).await
} }
/// Encrypt and send some plaintext `data`. `limit` controls
/// whether the per-connection buffer limits apply.
///
/// Returns the number of bytes written from `data`: this might
/// be less than `data.len()` if buffer limits were exceeded.
async fn send_plain(&mut self, data: &[u8], limit: Limit) -> Result<usize, Error> {
if !self.may_send_application_data {
// If we haven't completed handshaking, buffer
// plaintext to send once we do.
let len = match limit {
Limit::Yes => self.sendable_plaintext.append_limited_copy(data),
Limit::No => self.sendable_plaintext.append(data.to_vec()),
};
return Ok(len);
}
debug_assert!(self.record_layer.is_encrypting());
if data.is_empty() {
// Don't send empty fragments.
return Ok(0);
}
self.send_appdata_encrypt(data, limit).await
}
pub(crate) async fn start_outgoing_traffic(&mut self) -> Result<(), Error> { pub(crate) async fn start_outgoing_traffic(&mut self) -> Result<(), Error> {
self.may_send_application_data = true; self.may_send_application_data = true;
self.flush_plaintext().await self.flush_plaintext().await
@@ -1012,15 +982,14 @@ impl CommonState {
self.sendable_tls.set_limit(limit); self.sendable_tls.set_limit(limit);
} }
/// Send any buffered plaintext. Plaintext is buffered if /// Send and encrypt any buffered plaintext. Does nothing during handshake.
/// written during handshake. pub async fn flush_plaintext(&mut self) -> Result<(), Error> {
async fn flush_plaintext(&mut self) -> Result<(), Error> {
if !self.may_send_application_data { if !self.may_send_application_data {
return Ok(()); return Ok(());
} }
while let Some(buf) = self.sendable_plaintext.pop() { while let Some(buf) = self.sendable_plaintext.pop() {
self.send_plain(&buf, Limit::No).await?; self.send_appdata_encrypt(&buf, Limit::No).await?;
} }
Ok(()) Ok(())

View File

@@ -35,6 +35,15 @@ impl ChunkVecBuffer {
self.chunks.is_empty() self.chunks.is_empty()
} }
/// If the buffer has reached limit.
pub(crate) fn is_full(&self) -> bool {
if let Some(limit) = self.limit {
self.len() >= limit
} else {
false
}
}
/// How many bytes we're storing /// How many bytes we're storing
pub(crate) fn len(&self) -> usize { pub(crate) fn len(&self) -> usize {
let mut len = 0; let mut len = 0;

View File

@@ -247,7 +247,8 @@ async fn servered_client_data_sent() {
let (mut client, mut server) = let (mut client, mut server) =
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await; make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
assert_eq!(5, client.write_plaintext(b"hello").await.unwrap()); assert_eq!(5, client.write_plaintext(b"hello").unwrap());
client.flush_plaintext().await.unwrap();
do_handshake(&mut client, &mut server).await; do_handshake(&mut client, &mut server).await;
send(&mut client, &mut server); send(&mut client, &mut server);
@@ -286,7 +287,7 @@ async fn servered_both_data_sent() {
make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await; make_pair_for_arc_configs(&Arc::new(client_config), &server_config).await;
assert_eq!(12, server.writer().write(b"from-server!").unwrap()); assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap()); assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
do_handshake(&mut client, &mut server).await; do_handshake(&mut client, &mut server).await;
@@ -432,7 +433,7 @@ async fn server_close_notify() {
// check that alerts don't overtake appdata // check that alerts don't overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap()); assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap()); assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
server.send_close_notify(); server.send_close_notify();
receive(&mut server, &mut client); receive(&mut server, &mut client);
@@ -460,7 +461,8 @@ async fn client_close_notify() {
// check that alerts don't overtake appdata // check that alerts don't overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap()); assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap()); assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
client.flush_plaintext().await.unwrap();
client.send_close_notify().await.unwrap(); client.send_close_notify().await.unwrap();
send(&mut client, &mut server); send(&mut client, &mut server);
@@ -487,7 +489,7 @@ async fn server_closes_uncleanly() {
// check that unclean EOF reporting does not overtake appdata // check that unclean EOF reporting does not overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap()); assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap()); assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
receive(&mut server, &mut client); receive(&mut server, &mut client);
transfer_eof(&mut client); transfer_eof(&mut client);
@@ -518,7 +520,7 @@ async fn client_closes_uncleanly() {
// check that unclean EOF reporting does not overtake appdata // check that unclean EOF reporting does not overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap()); assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap()); assert_eq!(12, client.write_plaintext(b"from-client!").unwrap());
client.process_new_packets().await.unwrap(); client.process_new_packets().await.unwrap();
send(&mut client, &mut server); send(&mut client, &mut server);
@@ -900,20 +902,9 @@ async fn client_respects_buffer_limit_pre_handshake() {
client.set_buffer_limit(Some(32)); client.set_buffer_limit(Some(32));
assert_eq!( assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
client assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 12);
.write_plaintext(b"01234567890123456789") client.flush_plaintext().await.unwrap();
.await
.unwrap(),
20
);
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
12
);
do_handshake(&mut client, &mut server).await; do_handshake(&mut client, &mut server).await;
send(&mut client, &mut server); send(&mut client, &mut server);
@@ -953,20 +944,9 @@ async fn client_respects_buffer_limit_post_handshake() {
do_handshake(&mut client, &mut server).await; do_handshake(&mut client, &mut server).await;
client.set_buffer_limit(Some(48)); client.set_buffer_limit(Some(48));
assert_eq!( assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 20);
client assert_eq!(client.write_plaintext(b"01234567890123456789").unwrap(), 6);
.write_plaintext(b"01234567890123456789") client.flush_plaintext().await.unwrap();
.await
.unwrap(),
20
);
assert_eq!(
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap(),
6
);
send(&mut client, &mut server); send(&mut client, &mut server);
server.process_new_packets().unwrap(); server.process_new_packets().unwrap();
@@ -1211,14 +1191,8 @@ async fn client_complete_io_for_write() {
do_handshake(&mut client, &mut server).await; do_handshake(&mut client, &mut server).await;
client client.write_plaintext(b"01234567890123456789").unwrap();
.write_plaintext(b"01234567890123456789") client.write_plaintext(b"01234567890123456789").unwrap();
.await
.unwrap();
client
.write_plaintext(b"01234567890123456789")
.await
.unwrap();
{ {
let mut pipe = ServerSession::new(&mut server); let mut pipe = ServerSession::new(&mut server);
let (rdlen, wrlen) = client let (rdlen, wrlen) = client
@@ -1350,7 +1324,8 @@ async fn server_stream_read() {
for kt in ALL_KEY_TYPES.iter() { for kt in ALL_KEY_TYPES.iter() {
let (mut client, mut server) = make_pair(*kt).await; let (mut client, mut server) = make_pair(*kt).await;
client.write_all_plaintext(b"world").await.unwrap(); client.write_all_plaintext(b"world").unwrap();
client.process_new_packets().await.unwrap();
{ {
let mut pipe = ClientSession::new(&mut client); let mut pipe = ClientSession::new(&mut client);
@@ -1366,7 +1341,8 @@ async fn server_streamowned_read() {
for kt in ALL_KEY_TYPES.iter() { for kt in ALL_KEY_TYPES.iter() {
let (mut client, server) = make_pair(*kt).await; let (mut client, server) = make_pair(*kt).await;
client.write_all_plaintext(b"world").await.unwrap(); client.write_all_plaintext(b"world").unwrap();
client.process_new_packets().await.unwrap();
{ {
let pipe = ClientSession::new(&mut client); let pipe = ClientSession::new(&mut client);
@@ -1385,7 +1361,9 @@ async fn server_streamowned_read() {
// errkind: io::ErrorKind::ConnectionAborted, // errkind: io::ErrorKind::ConnectionAborted,
// after: 0, // after: 0,
// }; // };
// client.write_all_plaintext(b"hello").await.unwrap(); // client.write_all_plaintext(b"hello").unwrap();
// client.process_new_packets().await.unwrap();
//
// let mut client_stream = Stream::new(&mut client, &mut pipe); // let mut client_stream = Stream::new(&mut client, &mut pipe);
// let rc = client_stream.write(b"world"); // let rc = client_stream.write(b"world");
// assert!(rc.is_err()); // assert!(rc.is_err());
@@ -1402,7 +1380,9 @@ async fn server_streamowned_read() {
// errkind: io::ErrorKind::ConnectionAborted, // errkind: io::ErrorKind::ConnectionAborted,
// after: 1, // after: 1,
// }; // };
// client.write_all_plaintext(b"hello").await.unwrap(); // client.write_all_plaintext(b"hello").unwrap();
// client.process_new_packets().await.unwrap();
//
// let mut client_stream = Stream::new(&mut client, &mut pipe); // let mut client_stream = Stream::new(&mut client, &mut pipe);
// let rc = client_stream.write(b"world"); // let rc = client_stream.write(b"world");
// assert_eq!(format!("{:?}", rc), "Ok(5)"); // assert_eq!(format!("{:?}", rc), "Ok(5)");
@@ -1900,14 +1880,9 @@ async fn servered_write_for_client_appdata() {
let (mut client, mut server) = make_pair(KeyType::Rsa).await; let (mut client, mut server) = make_pair(KeyType::Rsa).await;
do_handshake(&mut client, &mut server).await; do_handshake(&mut client, &mut server).await;
client client.write_all_plaintext(b"01234567890123456789").unwrap();
.write_all_plaintext(b"01234567890123456789") client.write_all_plaintext(b"01234567890123456789").unwrap();
.await client.process_new_packets().await.unwrap();
.unwrap();
client
.write_all_plaintext(b"01234567890123456789")
.await
.unwrap();
{ {
let mut pipe = ServerSession::new(&mut server); let mut pipe = ServerSession::new(&mut server);
let wrlen = client.write_tls(&mut pipe).unwrap(); let wrlen = client.write_tls(&mut pipe).unwrap();
@@ -2019,11 +1994,10 @@ async fn servered_write_for_server_handshake_no_half_rtt_by_default() {
async fn servered_write_for_client_handshake() { async fn servered_write_for_client_handshake() {
let (mut client, mut server) = make_pair(KeyType::Rsa).await; let (mut client, mut server) = make_pair(KeyType::Rsa).await;
client client.write_all_plaintext(b"01234567890123456789").unwrap();
.write_all_plaintext(b"01234567890123456789") client.write_all_plaintext(b"0123456789").unwrap();
.await client.process_new_packets().await.unwrap();
.unwrap();
client.write_all_plaintext(b"0123456789").await.unwrap();
{ {
let mut pipe = ServerSession::new(&mut server); let mut pipe = ServerSession::new(&mut server);
let wrlen = client.write_tls(&mut pipe).unwrap(); let wrlen = client.write_tls(&mut pipe).unwrap();

View File

@@ -21,11 +21,11 @@ tlsn-attestation = { workspace = true }
tlsn-core = { workspace = true } tlsn-core = { workspace = true }
tlsn-deap = { workspace = true } tlsn-deap = { workspace = true }
tlsn-tls-client = { workspace = true } tlsn-tls-client = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tlsn-tls-core = { workspace = true } tlsn-tls-core = { workspace = true }
tlsn-mpc-tls = { workspace = true } tlsn-mpc-tls = { workspace = true }
tlsn-cipher = { workspace = true } tlsn-cipher = { workspace = true }
futures-plex = { workspace = true }
serio = { workspace = true, features = ["compat"] } serio = { workspace = true, features = ["compat"] }
uid-mux = { workspace = true, features = ["serio"] } uid-mux = { workspace = true, features = ["serio"] }
web-spawn = { workspace = true, optional = true } web-spawn = { workspace = true, optional = true }
@@ -57,6 +57,7 @@ serde = { workspace = true, features = ["derive"] }
ghash = { workspace = true } ghash = { workspace = true }
semver = { workspace = true, features = ["serde"] } semver = { workspace = true, features = ["serde"] }
once_cell = { workspace = true } once_cell = { workspace = true }
pin-project-lite = { workspace = true }
rangeset = { workspace = true } rangeset = { workspace = true }
webpki-roots = { workspace = true } webpki-roots = { workspace = true }

View File

@@ -1,21 +0,0 @@
//! Execution context.
use mpz_common::context::Multithread;
use crate::mux::MuxControl;
/// Maximum concurrency for multi-threaded context.
pub(crate) const MAX_CONCURRENCY: usize = 8;
/// Builds a multi-threaded context with the given muxer.
pub(crate) fn build_mt_context(mux: MuxControl) -> Multithread {
let builder = Multithread::builder().mux(mux).concurrency(MAX_CONCURRENCY);
#[cfg(all(feature = "web", target_arch = "wasm32"))]
let builder = builder.spawn_handler(|f| {
let _ = web_spawn::spawn(f);
Ok(())
});
builder.build().unwrap()
}

View File

@@ -4,7 +4,6 @@
#![deny(clippy::all)] #![deny(clippy::all)]
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
pub(crate) mod context;
pub(crate) mod ghash; pub(crate) mod ghash;
pub(crate) mod map; pub(crate) mod map;
pub(crate) mod mpz; pub(crate) mod mpz;
@@ -13,6 +12,7 @@ pub(crate) mod mux;
pub mod prover; pub mod prover;
pub(crate) mod tag; pub(crate) mod tag;
pub(crate) mod transcript_internal; pub(crate) mod transcript_internal;
pub(crate) mod utils;
pub mod verifier; pub mod verifier;
pub use tlsn_attestation as attestation; pub use tlsn_attestation as attestation;
@@ -27,6 +27,8 @@ pub(crate) static VERSION: LazyLock<Version> = LazyLock::new(|| {
Version::parse(env!("CARGO_PKG_VERSION")).expect("cargo pkg version should be a valid semver") Version::parse(env!("CARGO_PKG_VERSION")).expect("cargo pkg version should be a valid semver")
}); });
const BUF_CAP: usize = 16 * 1024;
/// The party's role in the TLSN protocol. /// The party's role in the TLSN protocol.
/// ///
/// A Notary is classified as a Verifier. /// A Notary is classified as a Verifier.

View File

@@ -1,31 +1,38 @@
//! Prover. //! Prover.
mod client;
mod conn;
mod control;
mod error; mod error;
mod future;
mod prove; mod prove;
pub mod state; pub mod state;
pub use conn::TlsConnection;
pub use control::ProverControl;
pub use error::ProverError; pub use error::ProverError;
pub use future::ProverFuture;
pub use tlsn_core::ProverOutput; pub use tlsn_core::ProverOutput;
use crate::{ use crate::{
Role, BUF_CAP, Role,
context::build_mt_context,
mpz::{ProverDeps, build_prover_deps, translate_keys}, mpz::{ProverDeps, build_prover_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg}, msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::attach_mux, mux::attach_mux,
tag::verify_tags, prover::{
client::{MpcTlsClient, TlsOutput},
state::ConnectedProj,
},
utils::{CopyIo, await_with_copy_io, build_mt_context},
}; };
use futures::{AsyncRead, AsyncWrite, TryFutureExt}; use futures::{AsyncRead, AsyncReadExt, AsyncWrite, FutureExt, TryFutureExt, ready};
use mpc_tls::LeaderCtrl;
use mpz_vm_core::prelude::*;
use rustls_pki_types::CertificateDer; use rustls_pki_types::CertificateDer;
use serio::{SinkExt, stream::IoStreamExt}; use serio::{SinkExt, stream::IoStreamExt};
use std::sync::Arc; use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tls_client::{ClientConnection, ServerName as TlsServerName}; use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client};
use tlsn_core::{ use tlsn_core::{
config::{ config::{
prove::ProveConfig, prove::ProveConfig,
@@ -36,10 +43,9 @@ use tlsn_core::{
connection::{HandshakeData, ServerName}, connection::{HandshakeData, ServerName},
transcript::{TlsTranscript, Transcript}, transcript::{TlsTranscript, Transcript},
}; };
use tracing::{Span, debug, info_span, instrument};
use webpki::anchor_from_trusted_cert; use webpki::anchor_from_trusted_cert;
use tracing::{Instrument, Span, debug, info, info_span, instrument};
/// A prover instance. /// A prover instance.
#[derive(Debug)] #[derive(Debug)]
pub struct Prover<T: state::ProverState = state::Initialized> { pub struct Prover<T: state::ProverState = state::Initialized> {
@@ -71,14 +77,27 @@ impl Prover<state::Initialized> {
/// # Arguments /// # Arguments
/// ///
/// * `config` - The TLS commitment configuration. /// * `config` - The TLS commitment configuration.
/// * `socket` - The socket to the TLS verifier. /// * `verifier_io` - The IO to the TLS verifier.
#[instrument(parent = &self.span, level = "debug", skip_all, err)] pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin>(
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self, self,
config: TlsCommitConfig, config: TlsCommitConfig,
socket: S, verifier_io: S,
) -> Result<Prover<state::CommitAccepted>, ProverError> { ) -> Result<Prover<state::CommitAccepted>, ProverError> {
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover); let (duplex_a, mut duplex_b) = futures_plex::duplex(BUF_CAP);
let fut = Box::pin(self.commit_inner(config, duplex_a).fuse());
let mut prover = await_with_copy_io(fut, verifier_io, &mut duplex_b).await?;
prover.state.verifier_io = Some(duplex_b);
Ok(prover)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn commit_inner<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
config: TlsCommitConfig,
verifier_io: S,
) -> Result<Prover<state::CommitAccepted>, ProverError> {
let (mut mux_fut, mux_ctrl) = attach_mux(verifier_io, Role::Prover);
let mut mt = build_mt_context(mux_ctrl.clone()); let mut mt = build_mt_context(mux_ctrl.clone());
let mut ctx = mux_fut.poll_with(mt.new_context()).await?; let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
@@ -122,6 +141,7 @@ impl Prover<state::Initialized> {
config: self.config, config: self.config,
span: self.span, span: self.span,
state: state::CommitAccepted { state: state::CommitAccepted {
verifier_io: None,
mux_ctrl, mux_ctrl,
mux_fut, mux_fut,
mpc_tls, mpc_tls,
@@ -133,31 +153,29 @@ impl Prover<state::Initialized> {
} }
impl Prover<state::CommitAccepted> { impl Prover<state::CommitAccepted> {
/// Connects to the server using the provided socket. /// Sets up the prover with the client configuration.
/// ///
/// Returns a handle to the TLS connection, a future which returns the /// Returns a set up prover, and a [`TlsConnection`] which can be used to
/// prover once the connection is closed and the TLS transcript is /// read and write bytes from/to the server.
/// committed.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `config` - The TLS client configuration. /// * `config` - The TLS client configuration.
/// * `socket` - The socket to the server.
#[instrument(parent = &self.span, level = "debug", skip_all, err)] #[instrument(parent = &self.span, level = "debug", skip_all, err)]
pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>( pub fn setup(
self, self,
config: TlsClientConfig, config: TlsClientConfig,
socket: S, ) -> Result<(TlsConnection, Prover<state::Setup>), ProverError> {
) -> Result<(TlsConnection, ProverFuture), ProverError> {
let state::CommitAccepted { let state::CommitAccepted {
verifier_io,
mux_ctrl, mux_ctrl,
mut mux_fut, mux_fut,
mpc_tls, mpc_tls,
keys, keys,
vm, vm,
..
} = self.state; } = self.state;
let decrypt = mpc_tls.is_decrypting();
let (mpc_ctrl, mpc_fut) = mpc_tls.run(); let (mpc_ctrl, mpc_fut) = mpc_tls.run();
let ServerName::Dns(server_name) = config.server_name(); let ServerName::Dns(server_name) = config.server_name();
@@ -202,95 +220,296 @@ impl Prover<state::CommitAccepted> {
) )
.map_err(ProverError::config)?; .map_err(ProverError::config)?;
let (conn, conn_fut) = bind_client(socket, client); let span = self.span.clone();
let fut = Box::pin({ let mpc_tls = MpcTlsClient::new(
let span = self.span.clone(); Box::new(mpc_fut.map_err(ProverError::from)),
let mpc_ctrl = mpc_ctrl.clone(); keys,
async move { vm,
let conn_fut = async { span,
mux_fut mpc_ctrl,
.poll_with(conn_fut.map_err(ProverError::from)) client,
.await?; decrypt,
);
mpc_ctrl.stop().await?; let (duplex_a, duplex_b) = futures_plex::duplex(BUF_CAP);
let prover = Prover {
Ok::<_, ProverError>(()) config: self.config,
}; span: self.span,
state: state::Setup {
info!("starting MPC-TLS"); mux_ctrl,
mux_fut,
let (_, (mut ctx, tls_transcript)) = futures::try_join!( server_name: config.server_name().clone(),
conn_fut, tls_client: Box::new(mpc_tls),
mpc_fut.in_current_span().map_err(ProverError::from) client_io: duplex_a,
)?; verifier_io,
info!("finished MPC-TLS");
{
let mut vm = vm.try_lock().expect("VM should not be locked");
debug!("finalizing mpc");
// Finalize DEAP.
mux_fut
.poll_with(vm.finalize(&mut ctx))
.await
.map_err(ProverError::mpc)?;
debug!("mpc finalized");
}
// Pull out ZK VM.
let (_, mut vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
// Prove tag verification of received records.
// The prover drops the proof output.
let _ = verify_tags(
&mut vm,
(keys.server_write_key, keys.server_write_iv),
keys.server_write_mac_key,
*tls_transcript.version(),
tls_transcript.recv().to_vec(),
)
.map_err(ProverError::zk)?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
.await?;
let transcript = tls_transcript
.to_transcript()
.expect("transcript is complete");
Ok(Prover {
config: self.config,
span: self.span,
state: state::Committed {
mux_ctrl,
mux_fut,
ctx,
vm,
server_name: config.server_name().clone(),
keys,
tls_transcript,
transcript,
},
})
}
.instrument(span)
});
Ok((
conn,
ProverFuture {
fut,
ctrl: ProverControl { mpc_ctrl },
}, },
)) };
let conn = TlsConnection::new(duplex_b);
Ok((conn, prover))
}
}
impl Prover<state::Setup> {
/// Returns a handle to control the prover.
pub fn handle(&self) -> ProverControl {
let handle = self.state.tls_client.handle();
ProverControl { handle }
}
/// Attaches IO to the prover.
///
/// # Arguments
///
/// * `server_io` - The IO to the server.
/// * `verifier_io` - The IO to the TLS verifier.
pub fn connect<S, T>(self, server_io: S, verifier_io: T) -> Prover<state::Connected<S, T>>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
T: AsyncRead + AsyncWrite + Send + Unpin,
{
let (client_to_server, server_to_client) = futures_plex::duplex(BUF_CAP);
Prover {
config: self.config,
span: self.span,
state: state::Connected {
verifier_io: self.state.verifier_io,
mux_ctrl: self.state.mux_ctrl,
mux_fut: self.state.mux_fut,
server_name: self.state.server_name,
tls_client: self.state.tls_client,
client_io: self.state.client_io,
output: None,
server_socket: server_io,
verifier_socket: verifier_io,
tls_client_to_server_buf: client_to_server,
server_to_tls_client_buf: server_to_client,
client_closed: false,
server_closed: false,
},
}
}
/// This is a convenience method which attaches IO, runs the prover and
/// returns a committed prover together with the IO.
///
/// # Arguments
///
/// * `server_io` - The IO to the server.
/// * `verifier_io` - The IO to the TLS verifier.
pub async fn run<S, T>(
self,
mut server_io: S,
mut verifier_io: T,
) -> Result<(Prover<state::Committed>, S, T), ProverError>
where
S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
let mut prover = self.connect(&mut server_io, &mut verifier_io);
(&mut prover).await?;
let prover = prover.finish()?;
Ok((prover, server_io, verifier_io))
}
}
impl<S, T> Future for Prover<state::Connected<S, T>>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
T: AsyncRead + AsyncWrite + Send + Unpin,
{
type Output = Result<(), ProverError>;
fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut state = Pin::new(&mut self.state).project();
loop {
let mut progress = false;
if state.output.is_none()
&& let Poll::Ready(output) = state.tls_client.poll(cx)?
{
*state.output = Some(output);
}
progress |= Self::io_client_conn(&mut state, cx)?;
progress |= Self::io_client_server(&mut state, cx)?;
progress |= Self::io_client_verifier(&mut state, cx)?;
_ = state.mux_fut.poll_unpin(cx)?;
if *state.server_closed && state.output.is_some() {
ready!(state.client_io.poll_close(cx))?;
ready!(state.server_socket.poll_close(cx))?;
return Poll::Ready(Ok(()));
} else if !progress {
return Poll::Pending;
}
}
}
}
impl<S, T> Prover<state::Connected<S, T>>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
T: AsyncRead + AsyncWrite + Send + Unpin,
{
fn io_client_conn(
state: &mut ConnectedProj<S, T>,
cx: &mut Context,
) -> Result<bool, ProverError> {
let mut progress = false;
// tls_conn -> tls_client
if state.tls_client.wants_write()
&& let Poll::Ready(mut simplex) = state.client_io.as_mut().poll_lock_read(cx)
&& let Poll::Ready(buf) = simplex.poll_get(cx)?
{
if !buf.is_empty() {
let write = state.tls_client.write(buf)?;
if write > 0 {
progress = true;
simplex.advance(write);
}
} else if !*state.client_closed && !*state.server_closed {
progress = true;
*state.client_closed = true;
state.tls_client.client_close()?;
}
}
// tls_client -> tls_conn
if state.tls_client.wants_read()
&& let Poll::Ready(mut simplex) = state.client_io.as_mut().poll_lock_write(cx)
&& let Poll::Ready(buf) = simplex.poll_mut(cx)?
&& let read = state.tls_client.read(buf)?
&& read > 0
{
progress = true;
simplex.advance_mut(read);
}
Ok(progress)
}
fn io_client_server(
state: &mut ConnectedProj<S, T>,
cx: &mut Context,
) -> Result<bool, ProverError> {
let mut progress = false;
// server_socket -> buf
if let Poll::Ready(write) = state
.server_to_tls_client_buf
.poll_write_from(cx, state.server_socket.as_mut())?
{
if write > 0 {
progress = true;
} else if !*state.server_closed {
progress = true;
*state.server_closed = true;
state.tls_client.server_close()?;
}
}
// buf -> tls_client
if state.tls_client.wants_read_tls()
&& let Poll::Ready(mut simplex) =
state.tls_client_to_server_buf.as_mut().poll_lock_read(cx)
&& let Poll::Ready(buf) = simplex.poll_get(cx)?
&& let read = state.tls_client.read_tls(buf)?
&& read > 0
{
progress = true;
simplex.advance(read);
}
// tls_client -> buf
if state.tls_client.wants_write_tls()
&& let Poll::Ready(mut simplex) =
state.tls_client_to_server_buf.as_mut().poll_lock_write(cx)
&& let Poll::Ready(buf) = simplex.poll_mut(cx)?
&& let write = state.tls_client.write_tls(buf)?
&& write > 0
{
progress = true;
simplex.advance_mut(write);
}
// buf -> server_socket
if let Poll::Ready(read) = state
.server_to_tls_client_buf
.poll_read_to(cx, state.server_socket.as_mut())?
&& read > 0
{
progress = true;
}
Ok(progress)
}
fn io_client_verifier(
state: &mut ConnectedProj<S, T>,
cx: &mut Context,
) -> Result<bool, ProverError> {
let mut progress = false;
let verifier_io = Pin::new(
(*state.verifier_io)
.as_mut()
.expect("verifier io should be available"),
);
// mux -> verifier_socket
if let Poll::Ready(read) = verifier_io.poll_read_to(cx, state.verifier_socket.as_mut())?
&& read > 0
{
progress = true;
}
// verifier_socket -> mux
if let Poll::Ready(write) =
verifier_io.poll_write_from(cx, state.verifier_socket.as_mut())?
&& write > 0
{
progress = true;
}
Ok(progress)
}
/// Returns a committed prover after the TLS session has completed.
pub fn finish(self) -> Result<Prover<state::Committed>, ProverError> {
let TlsOutput {
ctx,
vm,
keys,
tls_transcript,
transcript,
} = self.state.output.ok_or(ProverError::state(
"prover has not yet closed the connection",
))?;
let prover = Prover {
config: self.config,
span: self.span,
state: state::Committed {
verifier_io: self.state.verifier_io,
mux_ctrl: self.state.mux_ctrl,
mux_fut: self.state.mux_fut,
ctx,
vm,
server_name: self.state.server_name,
keys,
tls_transcript,
transcript,
},
};
Ok(prover)
} }
} }
@@ -310,8 +529,30 @@ impl Prover<state::Committed> {
/// # Arguments /// # Arguments
/// ///
/// * `config` - The disclosure configuration. /// * `config` - The disclosure configuration.
/// * `verifier_io` - The IO to the TLS verifier.
pub async fn prove<S>(
&mut self,
config: &ProveConfig,
verifier_io: S,
) -> Result<ProverOutput, ProverError>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
{
let mut duplex = self
.state
.verifier_io
.take()
.expect("duplex should be available");
let fut = Box::pin(self.prove_inner(config).fuse());
let output = await_with_copy_io(fut, verifier_io, &mut duplex).await?;
self.state.verifier_io = Some(duplex);
Ok(output)
}
#[instrument(parent = &self.span, level = "info", skip_all, err)] #[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn prove(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> { async fn prove_inner(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> {
let state::Committed { let state::Committed {
mux_fut, mux_fut,
ctx, ctx,
@@ -364,44 +605,31 @@ impl Prover<state::Committed> {
} }
/// Closes the connection with the verifier. /// Closes the connection with the verifier.
///
/// # Arguments
///
/// * `verifier_io` - The IO to the TLS verifier.
#[instrument(parent = &self.span, level = "info", skip_all, err)] #[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn close(self) -> Result<(), ProverError> { pub async fn close<S>(mut self, mut verifier_io: S) -> Result<(), ProverError>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
{
let state::Committed { let state::Committed {
mux_ctrl, mux_fut, .. mux_ctrl, mux_fut, ..
} = self.state; } = self.state;
// Wait for the verifier to correctly close the connection. let mut duplex = self
if !mux_fut.is_complete() { .state
mux_ctrl.close(); .verifier_io
mux_fut.await?; .take()
} .expect("duplex should be available");
mux_ctrl.close();
let copy = CopyIo::new(&mut verifier_io, &mut duplex).map_err(ProverError::from);
futures::try_join!(mux_fut.map_err(ProverError::from), copy)?;
// Wait for the verifier to finish closing.
verifier_io.read_exact(&mut [0_u8; 5]).await?;
Ok(()) Ok(())
} }
} }
/// A controller for the prover.
#[derive(Clone)]
pub struct ProverControl {
mpc_ctrl: LeaderCtrl,
}
impl ProverControl {
/// Defers decryption of data from the server until the server has closed
/// the connection.
///
/// This is a performance optimization which will significantly reduce the
/// amount of upload bandwidth used by the prover.
///
/// # Notes
///
/// * The prover may need to close the connection to the server in order for
/// it to close the connection on its end. If neither the prover or server
/// close the connection this will cause a deadlock.
pub async fn defer_decryption(&self) -> Result<(), ProverError> {
self.mpc_ctrl
.defer_decryption()
.await
.map_err(ProverError::from)
}
}

View File

@@ -0,0 +1,93 @@
//! Provides a TLS client.
use crate::{mpz::ProverZk, prover::control::ControlError};
use mpc_tls::SessionKeys;
use std::{
sync::mpsc::{Sender, SyncSender, sync_channel},
task::{Context, Poll},
};
use tlsn_core::transcript::{TlsTranscript, Transcript};
mod mpc;
pub(crate) use mpc::MpcTlsClient;
/// TLS client for MPC and proxy-based TLS implementations.
pub(crate) trait TlsClient {
type Error: std::error::Error + Send + Sync + Unpin + 'static;
/// Returns `true` if the client wants to read TLS data from the server.
fn wants_read_tls(&self) -> bool;
/// Returns `true` if the client wants to write TLS data to the server.
fn wants_write_tls(&self) -> bool;
/// Reads TLS data from the server.
fn read_tls(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
/// Writes TLS data for the server into the provided buffer.
fn write_tls(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
/// Returns `true` if the client wants to read plaintext data.
fn wants_read(&self) -> bool;
/// Returns `true` if the client wants to write plaintext data.
fn wants_write(&self) -> bool;
/// Reads plaintext data from the server into the provided buffer.
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
/// Writes plaintext data to be sent to the server.
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
/// Client closes the connection.
fn client_close(&mut self) -> Result<(), Self::Error>;
/// Server closes the connection.
fn server_close(&mut self) -> Result<(), Self::Error>;
/// Returns a handle to control the client.
fn handle(&self) -> ClientHandle;
/// Polls the client to make progress.
fn poll(&mut self, cx: &mut Context) -> Poll<Result<TlsOutput, Self::Error>>;
}
#[derive(Clone, Debug)]
pub(crate) struct ClientHandle {
sender: Sender<Command>,
}
#[derive(Clone, Debug)]
pub(crate) enum Command {
IsDecrypting(SyncSender<bool>),
SetDecrypt(bool),
ClientClose,
ServerClose,
}
impl ClientHandle {
pub(crate) fn enable_decryption(&self, enable: bool) -> Result<(), ControlError> {
self.sender
.send(Command::SetDecrypt(enable))
.map_err(|_| ControlError)
}
pub(crate) fn is_decrypting(&self) -> bool {
let (sender, receiver) = sync_channel(1);
let Ok(_) = self.sender.send(Command::IsDecrypting(sender)) else {
return false;
};
receiver.recv().unwrap_or(false)
}
}
/// Output of a TLS session.
pub(crate) struct TlsOutput {
pub(crate) ctx: mpz_common::Context,
pub(crate) vm: ProverZk,
pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript: Transcript,
}

View File

@@ -0,0 +1,503 @@
//! Implementation of an MPC-TLS client.
use crate::{
mpz::{ProverMpc, ProverZk},
prover::{
ProverError,
client::{ClientHandle, Command, TlsClient, TlsOutput},
},
tag::verify_tags,
};
use futures::{Future, FutureExt};
use mpc_tls::{LeaderCtrl, SessionKeys};
use mpz_common::Context;
use mpz_vm_core::Execute;
use std::{
pin::Pin,
sync::{
Arc,
mpsc::{Receiver, Sender, channel},
},
task::Poll,
};
use tls_client::ClientConnection;
use tlsn_core::transcript::TlsTranscript;
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Span, debug, instrument, trace, warn};
pub(crate) type MpcFuture =
Box<dyn Future<Output = Result<(Context, TlsTranscript), ProverError>> + Send>;
type FinalizeFuture =
Box<dyn Future<Output = Result<(InnerState, Context, TlsTranscript), ProverError>> + Send>;
pub(crate) struct MpcTlsClient {
sender: Sender<Command>,
state: State,
decrypt: bool,
}
enum State {
Start {
mpc: Pin<MpcFuture>,
inner: Box<InnerState>,
receiver: Receiver<Command>,
},
Active {
mpc: Pin<MpcFuture>,
inner: Box<InnerState>,
receiver: Receiver<Command>,
},
Busy {
mpc: Pin<MpcFuture>,
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
receiver: Receiver<Command>,
},
MpcStop {
mpc: Pin<MpcFuture>,
inner: Box<InnerState>,
},
CloseBusy {
mpc: Pin<MpcFuture>,
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
},
Finishing {
ctx: Context,
transcript: Box<TlsTranscript>,
fut: Pin<Box<dyn Future<Output = Result<Box<InnerState>, ProverError>> + Send>>,
},
Finalizing {
fut: Pin<FinalizeFuture>,
},
Finished,
Error,
}
impl MpcTlsClient {
pub(crate) fn new(
mpc: MpcFuture,
keys: SessionKeys,
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
span: Span,
mpc_ctrl: LeaderCtrl,
tls: ClientConnection,
decrypt: bool,
) -> Self {
let inner = InnerState {
span,
tls,
vm,
keys,
mpc_ctrl,
client_closed: false,
mpc_stopped: false,
};
let (sender, receiver) = channel();
Self {
sender,
decrypt,
state: State::Start {
receiver,
mpc: Box::into_pin(mpc),
inner: Box::new(inner),
},
}
}
fn inner_client_mut(&mut self) -> Option<&mut ClientConnection> {
if let State::Active { inner, .. } | State::MpcStop { inner, .. } = &mut self.state {
Some(&mut inner.tls)
} else {
None
}
}
fn inner_client(&self) -> Option<&ClientConnection> {
if let State::Active { inner, .. } | State::MpcStop { inner, .. } = &self.state {
Some(&inner.tls)
} else {
None
}
}
}
impl TlsClient for MpcTlsClient {
type Error = ProverError;
fn wants_read_tls(&self) -> bool {
if let Some(client) = self.inner_client() {
client.wants_read()
} else {
false
}
}
fn wants_write_tls(&self) -> bool {
if let Some(client) = self.inner_client() {
client.wants_write()
} else {
false
}
}
fn read_tls(&mut self, mut buf: &[u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& client.wants_read()
{
client.read_tls(&mut buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn write_tls(&mut self, mut buf: &mut [u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& client.wants_write()
{
client.write_tls(&mut buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn wants_read(&self) -> bool {
if let Some(client) = self.inner_client() {
!client.plaintext_is_empty()
} else {
false
}
}
fn wants_write(&self) -> bool {
if let Some(client) = self.inner_client() {
!client.sendable_plaintext_is_full()
} else {
false
}
}
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& !client.plaintext_is_empty()
{
client.read_plaintext(buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
if let Some(client) = self.inner_client_mut()
&& !client.sendable_plaintext_is_full()
{
client.write_plaintext(buf).map_err(ProverError::from)
} else {
Ok(0)
}
}
fn client_close(&mut self) -> Result<(), Self::Error> {
self.sender
.send(Command::ClientClose)
.map_err(|_| ProverError::state("unable to close connection clientside"))
}
fn server_close(&mut self) -> Result<(), Self::Error> {
self.sender
.send(Command::ServerClose)
.map_err(|_| ProverError::state("unable to close connection serverside"))
}
fn handle(&self) -> ClientHandle {
ClientHandle {
sender: self.sender.clone(),
}
}
fn poll(&mut self, cx: &mut std::task::Context) -> Poll<Result<TlsOutput, Self::Error>> {
match std::mem::replace(&mut self.state, State::Error) {
State::Start {
mpc,
inner,
receiver,
} => {
trace!("inner client is starting");
self.state = State::Busy {
mpc,
fut: Box::pin(inner.start()),
receiver,
};
self.poll(cx)
}
State::Active {
mpc,
inner,
receiver,
} => {
trace!("inner client is active");
if !inner.tls.is_handshaking()
&& let Ok(cmd) = receiver.try_recv()
{
match cmd {
Command::ClientClose => {
self.state = State::Busy {
mpc,
fut: Box::pin(inner.client_close()),
receiver,
};
}
Command::ServerClose => {
std::mem::drop(receiver);
self.state = State::CloseBusy {
mpc,
fut: Box::pin(inner.server_close()),
};
}
Command::SetDecrypt(enable) => {
self.decrypt = enable;
self.state = State::Busy {
mpc,
fut: Box::pin(inner.set_decrypt(enable)),
receiver,
};
}
Command::IsDecrypting(sender) => {
_ = sender.send(self.decrypt);
self.state = State::Busy {
mpc,
fut: Box::pin(inner.run()),
receiver,
};
}
}
} else {
self.state = State::Busy {
mpc,
fut: Box::pin(inner.run()),
receiver,
};
}
self.poll(cx)
}
State::Busy {
mut mpc,
mut fut,
receiver,
} => {
trace!("inner client is busy");
let mpc_poll = mpc.as_mut().poll(cx)?;
assert!(
matches!(mpc_poll, Poll::Pending),
"mpc future should not be finished here"
);
match fut.as_mut().poll(cx)? {
Poll::Ready(inner) => {
self.state = State::Active {
mpc,
inner,
receiver,
};
}
Poll::Pending => self.state = State::Busy { mpc, fut, receiver },
}
Poll::Pending
}
State::MpcStop { mpc, inner } => {
trace!("inner client is stopping mpc");
self.state = State::CloseBusy {
mpc,
fut: Box::pin(inner.stop()),
};
self.poll(cx)
}
State::CloseBusy { mut mpc, mut fut } => {
trace!("inner client is busy closing");
match (fut.poll_unpin(cx)?, mpc.poll_unpin(cx)?) {
(Poll::Ready(inner), Poll::Ready((ctx, transcript))) => {
self.state = State::Finalizing {
fut: Box::pin(inner.finalize(ctx, transcript)),
};
self.poll(cx)
}
(Poll::Ready(inner), Poll::Pending) => {
self.state = State::MpcStop { mpc, inner };
Poll::Pending
}
(Poll::Pending, Poll::Ready((ctx, transcript))) => {
self.state = State::Finishing {
ctx,
transcript: Box::new(transcript),
fut,
};
Poll::Pending
}
(Poll::Pending, Poll::Pending) => {
self.state = State::CloseBusy { mpc, fut };
Poll::Pending
}
}
}
State::Finishing {
ctx,
transcript,
mut fut,
} => {
trace!("inner client is finishing");
if let Poll::Ready(inner) = fut.poll_unpin(cx)? {
self.state = State::Finalizing {
fut: Box::pin(inner.finalize(ctx, *transcript)),
};
self.poll(cx)
} else {
self.state = State::Finishing {
ctx,
transcript,
fut,
};
Poll::Pending
}
}
State::Finalizing { mut fut } => match fut.poll_unpin(cx) {
Poll::Ready(output) => {
let (inner, ctx, tls_transcript) = output?;
let InnerState { vm, keys, .. } = inner;
let transcript = tls_transcript
.to_transcript()
.expect("transcript is complete");
let (_, vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
let output = TlsOutput {
ctx,
vm,
keys,
tls_transcript,
transcript,
};
self.state = State::Finished;
Poll::Ready(Ok(output))
}
Poll::Pending => {
self.state = State::Finalizing { fut };
Poll::Pending
}
},
State::Finished => Poll::Ready(Err(ProverError::state(
"mpc tls client polled again in finished state",
))),
State::Error => {
Poll::Ready(Err(ProverError::state("mpc tls client is in error state")))
}
}
}
}
struct InnerState {
span: Span,
tls: ClientConnection,
vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
keys: SessionKeys,
mpc_ctrl: LeaderCtrl,
client_closed: bool,
mpc_stopped: bool,
}
impl InnerState {
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn start(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.start().await?;
Ok(self)
}
#[instrument(parent = &self.span, level = "trace", skip_all, err)]
async fn run(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.process_new_packets().await?;
Ok(self)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn set_decrypt(self: Box<Self>, enable: bool) -> Result<Box<Self>, ProverError> {
self.mpc_ctrl.enable_decryption(enable).await?;
self.run().await
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn client_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
if !self.client_closed {
debug!("sending close notify");
if let Err(e) = self.tls.send_close_notify().await {
warn!("failed to send close_notify to server: {}", e);
}
self.client_closed = true;
}
self.run().await
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn server_close(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.process_new_packets().await?;
self.tls.server_closed().await?;
debug!("closed connection serverside");
Ok(self)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn stop(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
self.tls.process_new_packets().await?;
if !self.mpc_stopped && self.tls.plaintext_is_empty() && self.tls.is_empty().await? {
self.mpc_ctrl.stop().await?;
self.mpc_stopped = true;
debug!("stopped mpc");
}
Ok(self)
}
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
async fn finalize(
self,
mut ctx: Context,
transcript: TlsTranscript,
) -> Result<(Self, Context, TlsTranscript), ProverError> {
{
let mut vm = self.vm.try_lock().expect("VM should not be locked");
// Finalize DEAP.
vm.finalize(&mut ctx).await.map_err(ProverError::mpc)?;
debug!("mpc finalized");
// Pull out ZK VM.
let mut zk = vm.zk();
// Prove tag verification of received records.
// The prover drops the proof output.
let _ = verify_tags(
&mut *zk,
(self.keys.server_write_key, self.keys.server_write_iv),
self.keys.server_write_mac_key,
*transcript.version(),
transcript.recv().to_vec(),
)
.map_err(ProverError::zk)?;
debug!("verified tags from server");
zk.execute_all(&mut ctx).await.map_err(ProverError::zk)?
}
debug!("MPC-TLS done");
Ok((self, ctx, transcript))
}
}

View File

@@ -0,0 +1,66 @@
use futures::{AsyncRead, AsyncWrite, AsyncWriteExt};
use futures_plex::DuplexStream;
use std::{
pin::Pin,
task::{Context, Poll},
};
/// A TLS connection to a server.
///
/// This type implements [`AsyncRead`] and [`AsyncWrite`] and can be used to
/// communicate with a server using TLS.
///
/// # Note
///
/// This connection is closed on a best-effort basis if this is dropped. To
/// ensure a clean close, you should call
/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the
/// connection.
pub struct TlsConnection {
duplex: DuplexStream,
}
impl TlsConnection {
pub(crate) fn new(duplex: DuplexStream) -> Self {
Self { duplex }
}
}
impl Drop for TlsConnection {
fn drop(&mut self) {
if let Err(err) = futures::executor::block_on(self.duplex.close()) {
tracing::error!("error closing connection: {}", err);
}
}
}
impl AsyncRead for TlsConnection {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
let duplex = Pin::new(&mut self.duplex);
duplex.poll_read(cx, buf)
}
}
impl AsyncWrite for TlsConnection {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let duplex = Pin::new(&mut self.duplex);
duplex.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
let duplex = Pin::new(&mut self.duplex);
duplex.poll_close(cx)
}
}

View File

@@ -0,0 +1,29 @@
use crate::prover::client::ClientHandle;
/// A controller for the prover.
///
/// Can be used to control the decryption of server traffic.
#[derive(Clone, Debug)]
pub struct ProverControl {
pub(crate) handle: ClientHandle,
}
impl ProverControl {
/// Returns whether the prover is decrypting the server traffic.
pub fn is_decrypting(&self) -> bool {
self.handle.is_decrypting()
}
/// Enables or disables the decryption of server traffic.
///
/// # Arguments
///
/// * `enable` - If decryption should be enabled or disabled.
pub fn enable_decryption(&self, enable: bool) -> Result<(), ControlError> {
self.handle.enable_decryption(enable)
}
}
#[derive(Debug, thiserror::Error)]
#[error("Unable to send control command to prover.")]
pub struct ControlError;

View File

@@ -49,6 +49,13 @@ impl ProverError {
{ {
Self::new(ErrorKind::Commit, source) Self::new(ErrorKind::Commit, source)
} }
pub(crate) fn state<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::State, source)
}
} }
#[derive(Debug)] #[derive(Debug)]
@@ -58,6 +65,7 @@ enum ErrorKind {
Zk, Zk,
Config, Config,
Commit, Commit,
State,
} }
impl fmt::Display for ProverError { impl fmt::Display for ProverError {
@@ -70,6 +78,7 @@ impl fmt::Display for ProverError {
ErrorKind::Zk => f.write_str("zk error")?, ErrorKind::Zk => f.write_str("zk error")?,
ErrorKind::Config => f.write_str("config error")?, ErrorKind::Config => f.write_str("config error")?,
ErrorKind::Commit => f.write_str("commit error")?, ErrorKind::Commit => f.write_str("commit error")?,
ErrorKind::State => f.write_str("state error")?,
} }
if let Some(source) = &self.source { if let Some(source) = &self.source {
@@ -86,8 +95,8 @@ impl From<std::io::Error> for ProverError {
} }
} }
impl From<tls_client_async::ConnectionError> for ProverError { impl From<tls_client::Error> for ProverError {
fn from(e: tls_client_async::ConnectionError) -> Self { fn from(e: tls_client::Error) -> Self {
Self::new(ErrorKind::Io, e) Self::new(ErrorKind::Io, e)
} }
} }

View File

@@ -1,32 +0,0 @@
//! This module collects futures which are used by the [Prover].
use super::{Prover, ProverControl, ProverError, state};
use futures::Future;
use std::pin::Pin;
/// Prover future which must be polled for the TLS connection to make progress.
pub struct ProverFuture {
#[allow(clippy::type_complexity)]
pub(crate) fut: Pin<
Box<dyn Future<Output = Result<Prover<state::Committed>, ProverError>> + Send + 'static>,
>,
pub(crate) ctrl: ProverControl,
}
impl ProverFuture {
/// Returns a controller for the prover for advanced functionality.
pub fn control(&self) -> ProverControl {
self.ctrl.clone()
}
}
impl Future for ProverFuture {
type Output = Result<Prover<state::Committed>, ProverError>;
fn poll(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
self.fut.as_mut().poll(cx)
}
}

View File

@@ -2,6 +2,7 @@
use std::sync::Arc; use std::sync::Arc;
use futures_plex::DuplexStream;
use mpc_tls::{MpcTlsLeader, SessionKeys}; use mpc_tls::{MpcTlsLeader, SessionKeys};
use mpz_common::Context; use mpz_common::Context;
use tlsn_core::{ use tlsn_core::{
@@ -14,6 +15,10 @@ use tokio::sync::Mutex;
use crate::{ use crate::{
mpz::{ProverMpc, ProverZk}, mpz::{ProverMpc, ProverZk},
mux::{MuxControl, MuxFuture}, mux::{MuxControl, MuxFuture},
prover::{
ProverError,
client::{TlsClient, TlsOutput},
},
}; };
/// Entry state /// Entry state
@@ -24,6 +29,7 @@ opaque_debug::implement!(Initialized);
/// State after the verifier has accepted the proposed TLS commitment protocol /// State after the verifier has accepted the proposed TLS commitment protocol
/// configuration and preprocessing has completed. /// configuration and preprocessing has completed.
pub struct CommitAccepted { pub struct CommitAccepted {
pub(crate) verifier_io: Option<DuplexStream>,
pub(crate) mux_ctrl: MuxControl, pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture, pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsLeader, pub(crate) mpc_tls: MpcTlsLeader,
@@ -33,8 +39,49 @@ pub struct CommitAccepted {
opaque_debug::implement!(CommitAccepted); opaque_debug::implement!(CommitAccepted);
/// State when the TLS client has been setup.
pub struct Setup {
pub(crate) verifier_io: Option<DuplexStream>,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) server_name: ServerName,
pub(crate) tls_client: Box<dyn TlsClient<Error = ProverError> + Send>,
pub(crate) client_io: DuplexStream,
}
opaque_debug::implement!(Setup);
pin_project_lite::pin_project! {
/// State during the MPC-TLS connection.
#[project = ConnectedProj]
pub struct Connected<S, T> {
#[pin]
pub(crate) verifier_io: Option<DuplexStream>,
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) server_name: ServerName,
pub(crate) tls_client: Box<dyn TlsClient<Error = ProverError> + Send>,
#[pin]
pub(crate) client_io: DuplexStream,
pub(crate) output: Option<TlsOutput>,
#[pin]
pub(crate) server_socket: S,
#[pin]
pub(crate) verifier_socket: T,
#[pin]
pub(crate) tls_client_to_server_buf: DuplexStream,
#[pin]
pub(crate) server_to_tls_client_buf: DuplexStream,
pub(crate) client_closed: bool,
pub(crate) server_closed: bool
}
}
opaque_debug::implement!(Connected<S, T>);
/// State after the TLS transcript has been committed. /// State after the TLS transcript has been committed.
pub struct Committed { pub struct Committed {
pub(crate) verifier_io: Option<DuplexStream>,
pub(crate) mux_ctrl: MuxControl, pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture, pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context, pub(crate) ctx: Context,
@@ -52,11 +99,15 @@ pub trait ProverState: sealed::Sealed {}
impl ProverState for Initialized {} impl ProverState for Initialized {}
impl ProverState for CommitAccepted {} impl ProverState for CommitAccepted {}
impl ProverState for Setup {}
impl<S, T> ProverState for Connected<S, T> {}
impl ProverState for Committed {} impl ProverState for Committed {}
mod sealed { mod sealed {
pub trait Sealed {} pub trait Sealed {}
impl Sealed for super::Initialized {} impl Sealed for super::Initialized {}
impl Sealed for super::CommitAccepted {} impl Sealed for super::CommitAccepted {}
impl Sealed for super::Setup {}
impl<S, T> Sealed for super::Connected<S, T> {}
impl Sealed for super::Committed {} impl Sealed for super::Committed {}
} }

143
crates/tlsn/src/utils.rs Normal file
View File

@@ -0,0 +1,143 @@
//! Execution context.
use std::{
io::ErrorKind,
pin::Pin,
task::{Context, Poll},
};
use futures::{AsyncRead, AsyncWrite, future::FusedFuture};
use futures_plex::DuplexStream;
use mpz_common::context::Multithread;
use crate::mux::MuxControl;
/// Maximum concurrency for multi-threaded context.
pub(crate) const MAX_CONCURRENCY: usize = 8;
/// Builds a multi-threaded context with the given muxer.
pub(crate) fn build_mt_context(mux: MuxControl) -> Multithread {
let builder = Multithread::builder().mux(mux).concurrency(MAX_CONCURRENCY);
#[cfg(all(feature = "web", target_arch = "wasm32"))]
let builder = builder.spawn_handler(|f| {
let _ = web_spawn::spawn(f);
Ok(())
});
builder.build().unwrap()
}
/// Polls the future while copying bytes between two duplex streams.
///
/// Returns as soon as the future is ready, without closing IO.
pub(crate) async fn await_with_copy_io<'a, S, T>(
mut fut: Pin<Box<dyn FusedFuture<Output = T> + Send + 'a>>,
io: S,
duplex: &mut DuplexStream,
) -> T
where
S: AsyncRead + AsyncWrite + Send + Unpin,
{
let mut copy = CopyIo::new(io, duplex);
loop {
futures::select! {
_ = copy => (),
output = fut => break output
}
}
}
pin_project_lite::pin_project! {
#[derive(Debug)]
pub(crate) struct CopyIo<'a, S> {
#[pin]
io: S,
#[pin]
duplex: &'a mut DuplexStream,
io_done: bool,
duplex_done: bool,
}
}
impl<'a, S> CopyIo<'a, S> {
pub(crate) fn new(io: S, duplex: &'a mut DuplexStream) -> Self {
Self {
io,
duplex,
io_done: false,
duplex_done: false,
}
}
}
impl<'a, S> Future for CopyIo<'a, S>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
{
type Output = std::io::Result<()>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
loop {
let mut is_pending = true;
if !*this.duplex_done {
match this.duplex.poll_read_to(cx, this.io.as_mut()) {
Poll::Ready(Ok(read)) if read > 0 => is_pending = false,
Poll::Ready(Ok(_)) => {
is_pending = false;
*this.duplex_done = true;
}
Poll::Ready(Err(err))
if err.kind() == ErrorKind::BrokenPipe
|| err.kind() == ErrorKind::ConnectionReset
|| err.kind() == ErrorKind::NotConnected =>
{
is_pending = false;
*this.duplex_done = true;
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => (),
}
}
if !*this.io_done {
match this.duplex.poll_write_from(cx, this.io.as_mut()) {
Poll::Ready(Ok(write)) if write > 0 => is_pending = false,
Poll::Ready(Ok(_)) => {
is_pending = false;
*this.io_done = true;
}
Poll::Ready(Err(err))
if err.kind() == ErrorKind::BrokenPipe
|| err.kind() == ErrorKind::ConnectionReset
|| err.kind() == ErrorKind::NotConnected =>
{
is_pending = false;
*this.io_done = true
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => (),
}
}
if *this.io_done || *this.duplex_done {
return Poll::Ready(Ok(()));
} else if is_pending {
return Poll::Pending;
}
}
}
}
impl<'a, S> FusedFuture for CopyIo<'a, S>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
{
fn is_terminated(&self) -> bool {
self.duplex_done || self.io_done
}
}

View File

@@ -10,14 +10,14 @@ pub use error::VerifierError;
pub use tlsn_core::{VerifierOutput, webpki::ServerCertVerifier}; pub use tlsn_core::{VerifierOutput, webpki::ServerCertVerifier};
use crate::{ use crate::{
Role, BUF_CAP, Role,
context::build_mt_context,
mpz::{VerifierDeps, build_verifier_deps, translate_keys}, mpz::{VerifierDeps, build_verifier_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg}, msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::attach_mux, mux::attach_mux,
tag::verify_tags, tag::verify_tags,
utils::{CopyIo, await_with_copy_io, build_mt_context},
}; };
use futures::{AsyncRead, AsyncWrite, TryFutureExt}; use futures::{AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt, TryFutureExt};
use mpz_vm_core::prelude::*; use mpz_vm_core::prelude::*;
use serio::{SinkExt, stream::IoStreamExt}; use serio::{SinkExt, stream::IoStreamExt};
use tlsn_core::{ use tlsn_core::{
@@ -66,48 +66,73 @@ impl Verifier<state::Initialized> {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `socket` - The socket to the prover. /// * `prover_io` - The IO to the prover.
#[instrument(parent = &self.span, level = "info", skip_all, err)] #[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>( pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin>(
self, self,
socket: S, mut prover_io: S,
) -> Result<Verifier<state::CommitStart>, VerifierError> { ) -> Result<Verifier<state::CommitStart>, VerifierError> {
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Verifier); let (duplex_a, mut duplex_b) = futures_plex::duplex(BUF_CAP);
let mut mt = build_mt_context(mux_ctrl.clone());
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
// Receives protocol configuration from prover to perform compatibility check. let (mut mux_fut, mux_ctrl) = attach_mux(duplex_a, Role::Verifier);
let TlsCommitRequestMsg { request, version } = let mut mt = build_mt_context(mux_ctrl.clone());
mux_fut.poll_with(ctx.io_mut().expect_next()).await?;
let fut = Box::pin(
async {
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
// Receives protocol configuration from prover to perform compatibility check.
let TlsCommitRequestMsg { request, version } =
mux_fut.poll_with(ctx.io_mut().expect_next()).await?;
Ok::<_, VerifierError>((request, version, ctx))
}
.fuse(),
);
let (request, version, mut ctx) =
await_with_copy_io(fut, &mut prover_io, &mut duplex_b).await?;
if version != *crate::VERSION { if version != *crate::VERSION {
let msg = format!( let msg = format!(
"prover version does not match with verifier: {version} != {}", "prover version does not match with verifier: {version} != {}",
*crate::VERSION *crate::VERSION
); );
mux_fut
.poll_with(ctx.io_mut().send(Response::err(Some(msg.clone()))))
.await?;
// Wait for the prover to correctly close the connection. let fut = Box::pin(
if !mux_fut.is_complete() { async {
mux_ctrl.close(); mux_fut
mux_fut.await?; .poll_with(ctx.io_mut().send(Response::err(Some(msg.clone()))))
} .await?;
return Err(VerifierError::config(msg)); // Wait for the prover to correctly close the connection.
if !mux_fut.is_complete() {
mux_ctrl.close();
mux_fut.await?;
}
Err(VerifierError::config(msg))
}
.fuse(),
);
let copy = CopyIo::new(prover_io, &mut duplex_b).map_err(VerifierError::from);
let (config_err, _) = futures::try_join!(fut, copy)?;
return Err(config_err);
} }
Ok(Verifier { let verifier = Verifier {
config: self.config, config: self.config,
span: self.span, span: self.span,
state: state::CommitStart { state: state::CommitStart {
prover_io: Some(duplex_b),
mux_ctrl, mux_ctrl,
mux_fut, mux_fut,
ctx, ctx,
request, request,
}, },
}) };
Ok(verifier)
} }
} }
@@ -118,13 +143,36 @@ impl Verifier<state::CommitStart> {
} }
/// Accepts the proposed protocol configuration. /// Accepts the proposed protocol configuration.
///
/// # Arguments
///
/// * `prover_io` - The IO to the prover.
pub async fn accept<S: AsyncWrite + AsyncRead + Send + Unpin>(
mut self,
prover_io: S,
) -> Result<Verifier<state::CommitAccepted>, VerifierError> {
let mut duplex = self
.state
.prover_io
.take()
.expect("duplex should be available");
let fut = Box::pin(self.accept_inner().fuse());
let mut verifier = await_with_copy_io(fut, prover_io, &mut duplex).await?;
verifier.state.prover_io = Some(duplex);
Ok(verifier)
}
#[instrument(parent = &self.span, level = "info", skip_all, err)] #[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn accept(self) -> Result<Verifier<state::CommitAccepted>, VerifierError> { async fn accept_inner(self) -> Result<Verifier<state::CommitAccepted>, VerifierError> {
let state::CommitStart { let state::CommitStart {
prover_io,
mux_ctrl, mux_ctrl,
mut mux_fut, mut mux_fut,
mut ctx, mut ctx,
request, request,
..
} = self.state; } = self.state;
mux_fut.poll_with(ctx.io_mut().send(Response::ok())).await?; mux_fut.poll_with(ctx.io_mut().send(Response::ok())).await?;
@@ -151,6 +199,7 @@ impl Verifier<state::CommitStart> {
config: self.config, config: self.config,
span: self.span, span: self.span,
state: state::CommitAccepted { state: state::CommitAccepted {
prover_io,
mux_ctrl, mux_ctrl,
mux_fut, mux_fut,
mpc_tls, mpc_tls,
@@ -161,8 +210,31 @@ impl Verifier<state::CommitStart> {
} }
/// Rejects the proposed protocol configuration. /// Rejects the proposed protocol configuration.
///
/// # Arguments
///
/// * `prover_io` - The IO to the prover.
/// * `msg` - The optional rejection message.
pub async fn reject<S: AsyncWrite + AsyncRead + Send + Unpin>(
mut self,
prover_io: S,
msg: Option<&str>,
) -> Result<(), VerifierError> {
let mut duplex = self
.state
.prover_io
.take()
.expect("duplex should be available");
let fut = self.reject_inner(msg);
let copy = CopyIo::new(prover_io, &mut duplex).map_err(VerifierError::from);
futures::try_join!(fut, copy)?;
Ok(())
}
#[instrument(parent = &self.span, level = "info", skip_all, err)] #[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn reject(self, msg: Option<&str>) -> Result<(), VerifierError> { async fn reject_inner(self, msg: Option<&str>) -> Result<(), VerifierError> {
let state::CommitStart { let state::CommitStart {
mux_ctrl, mux_ctrl,
mut mux_fut, mut mux_fut,
@@ -186,9 +258,31 @@ impl Verifier<state::CommitStart> {
impl Verifier<state::CommitAccepted> { impl Verifier<state::CommitAccepted> {
/// Runs the verifier until the TLS connection is closed. /// Runs the verifier until the TLS connection is closed.
///
/// # Arguments
///
/// * `prover_io` - The IO to the prover.
pub async fn run<S: AsyncWrite + AsyncRead + Send + Unpin>(
mut self,
prover_io: S,
) -> Result<Verifier<state::Committed>, VerifierError> {
let mut duplex = self
.state
.prover_io
.take()
.expect("duplex should be available");
let fut = Box::pin(self.run_inner().fuse());
let mut verifier = await_with_copy_io(fut, prover_io, &mut duplex).await?;
verifier.state.prover_io = Some(duplex);
Ok(verifier)
}
#[instrument(parent = &self.span, level = "info", skip_all, err)] #[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn run(self) -> Result<Verifier<state::Committed>, VerifierError> { async fn run_inner(self) -> Result<Verifier<state::Committed>, VerifierError> {
let state::CommitAccepted { let state::CommitAccepted {
prover_io,
mux_ctrl, mux_ctrl,
mut mux_fut, mut mux_fut,
mpc_tls, mpc_tls,
@@ -232,6 +326,7 @@ impl Verifier<state::CommitAccepted> {
) )
.map_err(VerifierError::zk)?; .map_err(VerifierError::zk)?;
debug!("verifying tags");
mux_fut mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(VerifierError::zk)) .poll_with(vm.execute_all(&mut ctx).map_err(VerifierError::zk))
.await?; .await?;
@@ -241,10 +336,12 @@ impl Verifier<state::CommitAccepted> {
// authenticated from the verifier's perspective. // authenticated from the verifier's perspective.
tag_proof.verify().map_err(VerifierError::zk)?; tag_proof.verify().map_err(VerifierError::zk)?;
debug!("MPC-TLS done");
Ok(Verifier { Ok(Verifier {
config: self.config, config: self.config,
span: self.span, span: self.span,
state: state::Committed { state: state::Committed {
prover_io,
mux_ctrl, mux_ctrl,
mux_fut, mux_fut,
ctx, ctx,
@@ -263,9 +360,31 @@ impl Verifier<state::Committed> {
} }
/// Begins verification of statements from the prover. /// Begins verification of statements from the prover.
///
/// # Arguments
///
/// * `prover_io` - The IO to the prover.
pub async fn verify<S: AsyncWrite + AsyncRead + Send + Unpin>(
mut self,
prover_io: S,
) -> Result<Verifier<state::Verify>, VerifierError> {
let mut duplex = self
.state
.prover_io
.take()
.expect("duplex should be available");
let fut = Box::pin(self.verify_inner().fuse());
let mut verifier = await_with_copy_io(fut, prover_io, &mut duplex).await?;
verifier.state.prover_io = Some(duplex);
Ok(verifier)
}
#[instrument(parent = &self.span, level = "info", skip_all, err)] #[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn verify(self) -> Result<Verifier<state::Verify>, VerifierError> { async fn verify_inner(self) -> Result<Verifier<state::Verify>, VerifierError> {
let state::Committed { let state::Committed {
prover_io,
mux_ctrl, mux_ctrl,
mut mux_fut, mut mux_fut,
mut ctx, mut ctx,
@@ -286,6 +405,7 @@ impl Verifier<state::Committed> {
config: self.config, config: self.config,
span: self.span, span: self.span,
state: state::Verify { state: state::Verify {
prover_io,
mux_ctrl, mux_ctrl,
mux_fut, mux_fut,
ctx, ctx,
@@ -300,18 +420,36 @@ impl Verifier<state::Committed> {
} }
/// Closes the connection with the prover. /// Closes the connection with the prover.
///
/// # Arguments
///
/// * `prover_io` - The IO to the prover.
#[instrument(parent = &self.span, level = "info", skip_all, err)] #[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn close(self) -> Result<(), VerifierError> { pub async fn close<S: AsyncWrite + AsyncRead + Send + Unpin>(
let state::Committed { mut self,
mux_ctrl, mux_fut, .. mut prover_io: S,
} = self.state; ) -> Result<(), VerifierError> {
let state::Committed { mux_fut, .. } = self.state;
// Wait for the prover to correctly close the connection. let mut duplex = self
if !mux_fut.is_complete() { .state
mux_ctrl.close(); .prover_io
mux_fut.await?; .take()
} .expect("duplex should be available");
duplex.close().await?;
let fut: Box<dyn Future<Output = Result<(), VerifierError>> + Send + Unpin> =
if mux_fut.is_complete() {
Box::new(futures::future::ready(Ok::<_, VerifierError>(())))
} else {
Box::new(mux_fut.map_err(VerifierError::from))
};
let copy = CopyIo::new(&mut prover_io, &mut duplex).map_err(VerifierError::from);
futures::try_join!(fut, copy)?;
prover_io.write_all(b"close").await?;
Ok(()) Ok(())
} }
} }
@@ -323,10 +461,32 @@ impl Verifier<state::Verify> {
} }
/// Accepts the proving request. /// Accepts the proving request.
pub async fn accept( ///
/// # Arguments
///
/// * `prover_io` - The IO to the prover.
pub async fn accept<S: AsyncWrite + AsyncRead + Send + Unpin>(
mut self,
prover_io: S,
) -> Result<(VerifierOutput, Verifier<state::Committed>), VerifierError> {
let mut duplex = self
.state
.prover_io
.take()
.expect("duplex should be available");
let fut = Box::pin(self.accept_inner().fuse());
let (output, mut verifier) = await_with_copy_io(fut, prover_io, &mut duplex).await?;
verifier.state.prover_io = Some(duplex);
Ok((output, verifier))
}
async fn accept_inner(
self, self,
) -> Result<(VerifierOutput, Verifier<state::Committed>), VerifierError> { ) -> Result<(VerifierOutput, Verifier<state::Committed>), VerifierError> {
let state::Verify { let state::Verify {
prover_io,
mux_ctrl, mux_ctrl,
mut mux_fut, mut mux_fut,
mut ctx, mut ctx,
@@ -362,6 +522,7 @@ impl Verifier<state::Verify> {
config: self.config, config: self.config,
span: self.span, span: self.span,
state: state::Committed { state: state::Committed {
prover_io,
mux_ctrl, mux_ctrl,
mux_fut, mux_fut,
ctx, ctx,
@@ -374,11 +535,35 @@ impl Verifier<state::Verify> {
} }
/// Rejects the proving request. /// Rejects the proving request.
pub async fn reject( ///
/// # Arguments
///
/// * `prover_io` - The IO to the prover.
/// * `msg` - The optional rejection message.
pub async fn reject<S: AsyncWrite + AsyncRead + Send + Unpin>(
mut self,
prover_io: S,
msg: Option<&str>,
) -> Result<Verifier<state::Committed>, VerifierError> {
let mut duplex = self
.state
.prover_io
.take()
.expect("duplex should be available");
let fut = Box::pin(self.reject_inner(msg).fuse());
let mut verifier = await_with_copy_io(fut, prover_io, &mut duplex).await?;
verifier.state.prover_io = Some(duplex);
Ok(verifier)
}
async fn reject_inner(
self, self,
msg: Option<&str>, msg: Option<&str>,
) -> Result<Verifier<state::Committed>, VerifierError> { ) -> Result<Verifier<state::Committed>, VerifierError> {
let state::Verify { let state::Verify {
prover_io,
mux_ctrl, mux_ctrl,
mut mux_fut, mut mux_fut,
mut ctx, mut ctx,
@@ -396,6 +581,7 @@ impl Verifier<state::Verify> {
config: self.config, config: self.config,
span: self.span, span: self.span,
state: state::Committed { state: state::Committed {
prover_io,
mux_ctrl, mux_ctrl,
mux_fut, mux_fut,
ctx, ctx,

View File

@@ -3,6 +3,7 @@
use std::sync::Arc; use std::sync::Arc;
use crate::mux::{MuxControl, MuxFuture}; use crate::mux::{MuxControl, MuxFuture};
use futures_plex::DuplexStream;
use mpc_tls::{MpcTlsFollower, SessionKeys}; use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context; use mpz_common::Context;
use tlsn_core::{ use tlsn_core::{
@@ -25,6 +26,7 @@ opaque_debug::implement!(Initialized);
/// State after receiving protocol configuration from the prover. /// State after receiving protocol configuration from the prover.
pub struct CommitStart { pub struct CommitStart {
pub(crate) prover_io: Option<DuplexStream>,
pub(crate) mux_ctrl: MuxControl, pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture, pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context, pub(crate) ctx: Context,
@@ -36,6 +38,7 @@ opaque_debug::implement!(CommitStart);
/// State after accepting the proposed TLS commitment protocol configuration and /// State after accepting the proposed TLS commitment protocol configuration and
/// performing preprocessing. /// performing preprocessing.
pub struct CommitAccepted { pub struct CommitAccepted {
pub(crate) prover_io: Option<DuplexStream>,
pub(crate) mux_ctrl: MuxControl, pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture, pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsFollower, pub(crate) mpc_tls: MpcTlsFollower,
@@ -47,6 +50,7 @@ opaque_debug::implement!(CommitAccepted);
/// State after the TLS transcript has been committed. /// State after the TLS transcript has been committed.
pub struct Committed { pub struct Committed {
pub(crate) prover_io: Option<DuplexStream>,
pub(crate) mux_ctrl: MuxControl, pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture, pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context, pub(crate) ctx: Context,
@@ -59,6 +63,7 @@ opaque_debug::implement!(Committed);
/// State after receiving a proving request. /// State after receiving a proving request.
pub struct Verify { pub struct Verify {
pub(crate) prover_io: Option<DuplexStream>,
pub(crate) mux_ctrl: MuxControl, pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture, pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context, pub(crate) ctx: Context,

View File

@@ -110,7 +110,11 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
) -> (Transcript, ProverOutput) { ) -> (Transcript, ProverOutput) {
let (client_socket, server_socket) = tokio::io::duplex(2 << 16); let (client_socket, server_socket) = tokio::io::duplex(2 << 16);
let server_task = tokio::spawn(bind(server_socket.compat())); let client_socket = client_socket.compat();
let server_socket = server_socket.compat();
let mut verifier_socket = verifier_socket.compat();
let server_task = tokio::spawn(bind(server_socket));
let prover = Prover::new(ProverConfig::builder().build().unwrap()) let prover = Prover::new(ProverConfig::builder().build().unwrap())
.commit( .commit(
@@ -126,13 +130,13 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
) )
.build() .build()
.unwrap(), .unwrap(),
verifier_socket.compat(), &mut verifier_socket,
) )
.await .await
.unwrap(); .unwrap();
let (mut tls_connection, prover_fut) = prover let (mut tls_connection, prover) = prover
.connect( .setup(
TlsClientConfig::builder() TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap())) .server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.root_store(RootCertStore { .root_store(RootCertStore {
@@ -140,24 +144,23 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
}) })
.build() .build()
.unwrap(), .unwrap(),
client_socket.compat(),
) )
.await
.unwrap(); .unwrap();
let prover_task = tokio::spawn(prover_fut);
let prover_task = tokio::spawn(prover.run(client_socket, verifier_socket));
tls_connection tls_connection
.write_all(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n") .write_all(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
.await .await
.unwrap(); .unwrap();
tls_connection.close().await.unwrap();
let mut response = vec![0u8; 1024]; let mut response = vec![0u8; 1024];
tls_connection.read_to_end(&mut response).await.unwrap(); tls_connection.read_to_end(&mut response).await.unwrap();
tls_connection.close().await.unwrap();
let _ = server_task.await.unwrap(); let _ = server_task.await.unwrap();
let mut prover = prover_task.await.unwrap().unwrap(); let (mut prover, _, mut verifier_socket) = prover_task.await.unwrap().unwrap();
let sent_tx_len = prover.transcript().sent().len(); let sent_tx_len = prover.transcript().sent().len();
let recv_tx_len = prover.transcript().received().len(); let recv_tx_len = prover.transcript().received().len();
@@ -196,8 +199,8 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let config = builder.build().unwrap(); let config = builder.build().unwrap();
let transcript = prover.transcript().clone(); let transcript = prover.transcript().clone();
let output = prover.prove(&config).await.unwrap(); let output = prover.prove(&config, &mut verifier_socket).await.unwrap();
prover.close().await.unwrap(); prover.close(&mut verifier_socket).await.unwrap();
(transcript, output) (transcript, output)
} }
@@ -206,6 +209,8 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>( async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: T, socket: T,
) -> VerifierOutput { ) -> VerifierOutput {
let mut socket = socket.compat();
let verifier = Verifier::new( let verifier = Verifier::new(
VerifierConfig::builder() VerifierConfig::builder()
.root_store(RootCertStore { .root_store(RootCertStore {
@@ -216,18 +221,24 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
); );
let verifier = verifier let verifier = verifier
.commit(socket.compat()) .commit(&mut socket)
.await .await
.unwrap() .unwrap()
.accept() .accept(&mut socket)
.await .await
.unwrap() .unwrap()
.run() .run(&mut socket)
.await .await
.unwrap(); .unwrap();
let (output, verifier) = verifier.verify().await.unwrap().accept().await.unwrap(); let (output, verifier) = verifier
verifier.close().await.unwrap(); .verify(&mut socket)
.await
.unwrap()
.accept(&mut socket)
.await
.unwrap();
verifier.close(&mut socket).await.unwrap();
output output
} }

View File

@@ -23,9 +23,9 @@ no-bundler = ["web-spawn/no-bundler"]
tlsn-core = { workspace = true } tlsn-core = { workspace = true }
tlsn = { workspace = true, features = ["web", "mozilla-certs"] } tlsn = { workspace = true, features = ["web", "mozilla-certs"] }
tlsn-server-fixture-certs = { workspace = true } tlsn-server-fixture-certs = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tlsn-tls-core = { workspace = true } tlsn-tls-core = { workspace = true }
async_io_stream = { version = "0.3" }
bincode = { workspace = true } bincode = { workspace = true }
console_error_panic_hook = { version = "0.1" } console_error_panic_hook = { version = "0.1" }
enum-try-as-inner = { workspace = true } enum-try-as-inner = { workspace = true }

View File

@@ -2,11 +2,11 @@ mod config;
pub use config::ProverConfig; pub use config::ProverConfig;
use async_io_stream::IoStream;
use enum_try_as_inner::EnumTryAsInner; use enum_try_as_inner::EnumTryAsInner;
use futures::TryFutureExt; use futures::TryFutureExt;
use http_body_util::{BodyExt, Full}; use http_body_util::{BodyExt, Full};
use hyper::body::Bytes; use hyper::body::Bytes;
use tls_client_async::TlsConnection;
use tlsn::{ use tlsn::{
config::{ config::{
prove::ProveConfig, prove::ProveConfig,
@@ -14,13 +14,13 @@ use tlsn::{
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig}, tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig},
}, },
connection::ServerName, connection::ServerName,
prover::{state, Prover}, prover::{state, Prover, TlsConnection},
webpki::{CertificateDer, PrivateKeyDer, RootCertStore}, webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
}; };
use tracing::info; use tracing::info;
use wasm_bindgen::{prelude::*, JsError}; use wasm_bindgen::{prelude::*, JsError};
use wasm_bindgen_futures::spawn_local; use wasm_bindgen_futures::spawn_local;
use ws_stream_wasm::WsMeta; use ws_stream_wasm::{WsMeta, WsStreamIo};
use crate::{io::FuturesIo, types::*}; use crate::{io::FuturesIo, types::*};
@@ -36,8 +36,14 @@ pub struct JsProver {
#[derive_err(Debug)] #[derive_err(Debug)]
enum State { enum State {
Initialized(Prover<state::Initialized>), Initialized(Prover<state::Initialized>),
CommitAccepted(Prover<state::CommitAccepted>), CommitAccepted {
Committed(Prover<state::Committed>), prover: Prover<state::CommitAccepted>,
verifier_conn: IoStream<WsStreamIo, Vec<u8>>,
},
Committed {
prover: Prover<state::Committed>,
verifier_conn: IoStream<WsStreamIo, Vec<u8>>,
},
Complete, Complete,
Error, Error,
} }
@@ -96,12 +102,16 @@ impl JsProver {
info!("connecting to verifier"); info!("connecting to verifier");
let (_, verifier_conn) = WsMeta::connect(verifier_url, None).await?; let (_, verifier_conn) = WsMeta::connect(verifier_url, None).await?;
let mut verifier_conn = verifier_conn.into_io();
info!("connected to verifier"); info!("connected to verifier");
let prover = prover.commit(config, verifier_conn.into_io()).await?; let prover = prover.commit(config, &mut verifier_conn).await?;
self.state = State::CommitAccepted(prover); self.state = State::CommitAccepted {
prover,
verifier_conn,
};
Ok(()) Ok(())
} }
@@ -112,7 +122,7 @@ impl JsProver {
ws_proxy_url: &str, ws_proxy_url: &str,
request: HttpRequest, request: HttpRequest,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
let prover = self.state.take().try_into_commit_accepted()?; let (prover, mut verifier_conn) = self.state.take().try_into_commit_accepted()?;
let mut builder = TlsClientConfig::builder() let mut builder = TlsClientConfig::builder()
.server_name(ServerName::Dns( .server_name(ServerName::Dns(
@@ -145,35 +155,41 @@ impl JsProver {
info!("connecting to server"); info!("connecting to server");
let (_, server_conn) = WsMeta::connect(ws_proxy_url, None).await?; let (_, server_conn) = WsMeta::connect(ws_proxy_url, None).await?;
let mut server_conn = server_conn.into_io();
info!("connected to server"); info!("connected to server");
let (tls_conn, prover_fut) = prover.connect(config, server_conn.into_io()).await?; let (tls_conn, prover) = prover.setup(config)?;
let mut prover = prover.connect(&mut server_conn, &mut verifier_conn);
info!("sending request"); info!("sending request");
let (response, prover) = futures::try_join!( let (response, _) = futures::try_join!(
send_request(tls_conn, request), send_request(tls_conn, request),
prover_fut.map_err(Into::into) (&mut prover).map_err(Into::into)
)?; )?;
let prover = prover.finish()?;
info!("response received"); info!("response received");
self.state = State::Committed(prover); self.state = State::Committed {
prover,
verifier_conn,
};
Ok(response) Ok(response)
} }
/// Returns the transcript. /// Returns the transcript.
pub fn transcript(&self) -> Result<Transcript> { pub fn transcript(&self) -> Result<Transcript> {
let prover = self.state.try_as_committed()?; let (prover, _) = self.state.try_as_committed()?;
Ok(Transcript::from(prover.transcript())) Ok(Transcript::from(prover.transcript()))
} }
/// Reveals data to the verifier and finalizes the protocol. /// Reveals data to the verifier and finalizes the protocol.
pub async fn reveal(&mut self, reveal: Reveal) -> Result<()> { pub async fn reveal(&mut self, reveal: Reveal) -> Result<()> {
let mut prover = self.state.take().try_into_committed()?; let (mut prover, mut verifier_conn) = self.state.take().try_into_committed()?;
info!("revealing data"); info!("revealing data");
@@ -193,8 +209,8 @@ impl JsProver {
let config = builder.build()?; let config = builder.build()?;
prover.prove(&config).await?; prover.prove(&config, &mut verifier_conn).await?;
prover.close().await?; prover.close(&mut verifier_conn).await?;
info!("Finalized"); info!("Finalized");

View File

@@ -2,6 +2,7 @@ mod config;
pub use config::VerifierConfig; pub use config::VerifierConfig;
use async_io_stream::IoStream;
use enum_try_as_inner::EnumTryAsInner; use enum_try_as_inner::EnumTryAsInner;
use tlsn::{ use tlsn::{
config::tls_commit::TlsCommitProtocolConfig, config::tls_commit::TlsCommitProtocolConfig,
@@ -12,7 +13,7 @@ use tlsn::{
}; };
use tracing::info; use tracing::info;
use wasm_bindgen::prelude::*; use wasm_bindgen::prelude::*;
use ws_stream_wasm::{WsMeta, WsStream}; use ws_stream_wasm::{WsMeta, WsStreamIo};
use crate::types::VerifierOutput; use crate::types::VerifierOutput;
@@ -26,9 +27,13 @@ pub struct JsVerifier {
#[derive(EnumTryAsInner)] #[derive(EnumTryAsInner)]
#[derive_err(Debug)] #[derive_err(Debug)]
#[allow(unused_assignments)]
enum State { enum State {
Initialized(Verifier<state::Initialized>), Initialized(Verifier<state::Initialized>),
Connected((Verifier<state::Initialized>, WsStream)), Connected {
verifier: Verifier<state::Initialized>,
prover_conn: IoStream<WsStreamIo, Vec<u8>>,
},
Complete, Complete,
Error, Error,
} }
@@ -66,19 +71,23 @@ impl JsVerifier {
info!("Connecting to prover"); info!("Connecting to prover");
let (_, prover_conn) = WsMeta::connect(prover_url, None).await?; let (_, prover_conn) = WsMeta::connect(prover_url, None).await?;
let prover_conn = prover_conn.into_io();
info!("Connected to prover"); info!("Connected to prover");
self.state = State::Connected((verifier, prover_conn)); self.state = State::Connected {
verifier,
prover_conn,
};
Ok(()) Ok(())
} }
/// Verifies the connection and finalizes the protocol. /// Verifies the connection and finalizes the protocol.
pub async fn verify(&mut self) -> Result<VerifierOutput> { pub async fn verify(&mut self) -> Result<VerifierOutput> {
let (verifier, prover_conn) = self.state.take().try_into_connected()?; let (verifier, mut prover_conn) = self.state.take().try_into_connected()?;
let verifier = verifier.commit(prover_conn.into_io()).await?; let verifier = verifier.commit(&mut prover_conn).await?;
let request = verifier.request(); let request = verifier.request();
let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = request.protocol() else { let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = request.protocol() else {
@@ -98,11 +107,15 @@ impl JsVerifier {
}; };
if reject.is_some() { if reject.is_some() {
verifier.reject(reject).await?; verifier.reject(&mut prover_conn, reject).await?;
return Err(JsError::new("protocol configuration rejected")); return Err(JsError::new("protocol configuration rejected"));
} }
let verifier = verifier.accept().await?.run().await?; let verifier = verifier
.accept(&mut prover_conn)
.await?
.run(&mut prover_conn)
.await?;
let sent = verifier let sent = verifier
.tls_transcript() .tls_transcript()
@@ -129,8 +142,12 @@ impl JsVerifier {
}, },
}; };
let (output, verifier) = verifier.verify().await?.accept().await?; let (output, verifier) = verifier
verifier.close().await?; .verify(&mut prover_conn)
.await?
.accept(&mut prover_conn)
.await?;
verifier.close(&mut prover_conn).await?;
self.state = State::Complete; self.state = State::Complete;