diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..581e977 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "rust-analyzer.linkedProjects": [ + "interactive-demo/verifier-rs/Cargo.toml", + "interactive-demo/prover-rs/Cargo.toml" + ], +} \ No newline at end of file diff --git a/interactive-demo/.gitignore b/interactive-demo/.gitignore new file mode 100644 index 0000000..764515c --- /dev/null +++ b/interactive-demo/.gitignore @@ -0,0 +1,2 @@ +**/target/ +**/Cargo.lock diff --git a/interactive-demo/README.md b/interactive-demo/README.md new file mode 100644 index 0000000..330de72 --- /dev/null +++ b/interactive-demo/README.md @@ -0,0 +1,38 @@ +# Test Rust Prover + +1. Start the verifier: +```bash +cd verifier-rs; cargo run --release +``` +2. Run the prover: +```bash +cd prover-rs; cargo run --release +``` + +# Test Browser Prover +1. Start the verifier: +```bash +cd verifier-rs; cargo run --release +``` +2. Since a web browser doesn't have the ability to make TCP connection, we need to use a websocket proxy server to access . +```bash +cargo install wstcp + +wstcp --bind-addr 127.0.0.1:55688 swapi.dev:443 +``` +3. Run the prover + 1. Build tlsn-js + ```bash + cd .. + npm i + npm run build + npm link + ``` + 2. Build demo prover-ts + ```bash + cd prover-ts + npm i + npm link + npm run dev + ``` + 3. Open and click **Start Prover** \ No newline at end of file diff --git a/interactive-demo/prover-rs/Cargo.toml b/interactive-demo/prover-rs/Cargo.toml new file mode 100644 index 0000000..079dddc --- /dev/null +++ b/interactive-demo/prover-rs/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "interactive-networked-prover" +version = "0.1.0" +edition = "2021" + +[dependencies] +async-tungstenite = { version = "0.25", features = ["tokio-runtime"] } +futures = "0.3" +http = "1.1" +http-body-util = "0.1" +hyper = {version = "1.1", features = ["client", "http1"]} +hyper-util = {version = "0.1", features = ["full"]} +regex = "1.10.3" +tokio = {version = "1", features = [ + "rt", + "rt-multi-thread", + "macros", + "net", + "io-std", + "fs", +]} +tokio-util = { version = "0.7", features = ["compat"] } +tracing = "0.1.40" +tracing-subscriber = { version ="0.3.18", features = ["env-filter"] } +uuid = { version = "1.4.1", features = ["v4", "fast-rng"] } +ws_stream_tungstenite = { version = "0.13", features = ["tokio_io"] } + +tlsn-core = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.7", package = "tlsn-core" } +tlsn-prover = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.7", package = "tlsn-prover" } +tlsn-common = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.7", package = "tlsn-common" } \ No newline at end of file diff --git a/interactive-demo/prover-rs/README.md b/interactive-demo/prover-rs/README.md new file mode 100644 index 0000000..0b92d86 --- /dev/null +++ b/interactive-demo/prover-rs/README.md @@ -0,0 +1,9 @@ +## Interactive Prover +An implementation of the interactive prover in Rust. + +## Running the prover +1. Configure this prover setting via the global variables defined in [main.rs](./src/main.rs) — please ensure that the hardcoded `SERVER_URL` and `VERIFICATION_SESSION_ID` have the same values on the verifier side. +2. Start the prover by running the following in a terminal at the root of this crate. +```bash +cargo run --release +``` diff --git a/interactive-demo/prover-rs/src/main.rs b/interactive-demo/prover-rs/src/main.rs new file mode 100644 index 0000000..3633585 --- /dev/null +++ b/interactive-demo/prover-rs/src/main.rs @@ -0,0 +1,172 @@ +use async_tungstenite::{tokio::connect_async_with_config, tungstenite::protocol::WebSocketConfig}; +use http_body_util::Empty; +use hyper::{body::Bytes, Request, StatusCode, Uri}; +use hyper_util::rt::TokioIo; +use regex::Regex; +use tlsn_common::config::ProtocolConfig; +use tlsn_core::transcript::Idx; +use tlsn_prover::{state::Prove, Prover, ProverConfig}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; +use tracing::{debug, info}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; +use ws_stream_tungstenite::WsStream; + +const TRACING_FILTER: &str = "INFO"; + +const VERIFIER_HOST: &str = "localhost"; +const VERIFIER_PORT: u16 = 9816; +// Maximum number of bytes that can be sent from prover to server +const MAX_SENT_DATA: usize = 1 << 12; +// Maximum number of bytes that can be received by prover from server +const MAX_RECV_DATA: usize = 1 << 14; + +const SECRET: &str = "TLSNotary's private key 🤡"; +/// Make sure the following url's domain is the same as SERVER_DOMAIN on the verifier side +const SERVER_URL: &str = "https://swapi.dev/api/people/1"; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with(EnvFilter::try_from_default_env().unwrap_or_else(|_| TRACING_FILTER.into())) + .with(tracing_subscriber::fmt::layer()) + .init(); + + run_prover(VERIFIER_HOST, VERIFIER_PORT, SERVER_URL).await; +} + +async fn run_prover(verifier_host: &str, verifier_port: u16, server_uri: &str) { + info!("Sending websocket request..."); + let request = http::Request::builder() + .uri(format!("ws://{}:{}/verify", verifier_host, verifier_port,)) + .header("Host", verifier_host) + .header("Sec-WebSocket-Key", uuid::Uuid::new_v4().to_string()) + .header("Sec-WebSocket-Version", "13") + .header("Connection", "Upgrade") + .header("Upgrade", "Websocket") + .body(()) + .unwrap(); + + let (verifier_ws_stream, _) = + connect_async_with_config(request, Some(WebSocketConfig::default())) + .await + .unwrap(); + + info!("Websocket connection established!"); + let verifier_ws_socket = WsStream::new(verifier_ws_stream); + prover(verifier_ws_socket, server_uri).await; + info!("Proving is successful!"); +} + +async fn prover(verifier_socket: T, uri: &str) { + debug!("Starting proving..."); + + let uri = uri.parse::().unwrap(); + assert_eq!(uri.scheme().unwrap().as_str(), "https"); + let server_domain = uri.authority().unwrap().host(); + let server_port = uri.port_u16().unwrap_or(443); + + // Create prover and connect to verifier. + // + // Perform the setup phase with the verifier. + let prover = Prover::new( + ProverConfig::builder() + .server_name(server_domain) + .protocol_config( + ProtocolConfig::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(), + ) + .build() + .unwrap(), + ) + .setup(verifier_socket.compat()) + .await + .unwrap(); + + // Connect to TLS Server. + let tls_client_socket = tokio::net::TcpStream::connect((server_domain, server_port)) + .await + .unwrap(); + + // Pass server connection into the prover. + let (mpc_tls_connection, prover_fut) = + prover.connect(tls_client_socket.compat()).await.unwrap(); + + // Wrap the connection in a TokioIo compatibility layer to use it with hyper. + let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat()); + + // Spawn the Prover to run in the background. + let prover_task = tokio::spawn(prover_fut); + + // MPC-TLS Handshake. + let (mut request_sender, connection) = + hyper::client::conn::http1::handshake(mpc_tls_connection) + .await + .unwrap(); + + tokio::spawn(connection); + + // MPC-TLS: Send Request and wait for Response. + info!("Send Request and wait for Response"); + let request = Request::builder() + .uri(uri.clone()) + .header("Host", server_domain) + .header("Connection", "close") + .header("Secret", SECRET) + .method("GET") + .body(Empty::::new()) + .unwrap(); + let response = request_sender.send_request(request).await.unwrap(); + + debug!("TLS response: {:?}", response); + assert!(response.status() == StatusCode::OK); + + // Create proof for the Verifier. + let mut prover = prover_task.await.unwrap().unwrap().start_prove(); + + let idx_sent = redact_and_reveal_sent_data(&mut prover); + let idx_recv = redact_and_reveal_received_data(&mut prover); + + // Reveal parts of the transcript + prover.prove_transcript(idx_sent, idx_recv).await.unwrap(); + + // Finalize. + prover.finalize().await.unwrap() +} + +/// Redacts and reveals received data to the verifier. +fn redact_and_reveal_received_data(prover: &mut Prover) -> Idx { + let recv_transcript = prover.transcript().received(); + let recv_transcript_len = recv_transcript.len(); + + // Get the homeworld from the received data. + let received_string = String::from_utf8(recv_transcript.to_vec()).unwrap(); + debug!("Received data: {}", received_string); + let re = Regex::new(r#""homeworld"\s?:\s?"(.*?)""#).unwrap(); + let homeworld_match = re.captures(&received_string).unwrap().get(1).unwrap(); + + // Reveal everything except for the homeworld. + let start = homeworld_match.start(); + let end = homeworld_match.end(); + Idx::new([0..start, end..recv_transcript_len]) +} + +/// Redacts and reveals sent data to the verifier. +fn redact_and_reveal_sent_data(prover: &mut Prover) -> Idx { + let sent_transcript = prover.transcript().sent(); + let sent_transcript_len = sent_transcript.len(); + + let sent_string: String = String::from_utf8(sent_transcript.to_vec()).unwrap(); + let secret_start = sent_string.find(SECRET).unwrap(); + + debug!("Send data: {}", sent_string); + + // Reveal everything except for the SECRET. + Idx::new([ + 0..secret_start, + secret_start + SECRET.len()..sent_transcript_len, + ]) +} diff --git a/interactive-demo/prover-ts/.gitignore b/interactive-demo/prover-ts/.gitignore new file mode 100644 index 0000000..d8b83df --- /dev/null +++ b/interactive-demo/prover-ts/.gitignore @@ -0,0 +1 @@ +package-lock.json diff --git a/interactive-demo/prover-ts/app.tsx b/interactive-demo/prover-ts/app.tsx new file mode 100644 index 0000000..fa2a412 --- /dev/null +++ b/interactive-demo/prover-ts/app.tsx @@ -0,0 +1,140 @@ +import React, { ReactElement, useCallback, useState } from 'react'; +import { createRoot } from 'react-dom/client'; +import * as Comlink from 'comlink'; +import { Watch } from 'react-loader-spinner'; +import { Prover as TProver } from 'tlsn-js'; +import { type Method } from 'tlsn-wasm'; + +const { init, Prover }: any = Comlink.wrap( + new Worker(new URL('./worker.ts', import.meta.url)), +); + +const container = document.getElementById('root'); +const root = createRoot(container!); + +root.render(); + +function App(): ReactElement { + const [processing, setProcessing] = useState(false); + const [result, setResult] = useState(null); + + const onClick = useCallback(async () => { + setProcessing(true); + + const url = 'https://swapi.dev/api/people/1'; + const method: Method = 'GET'; + const headers = { + secret: "TLSNotary's private key", + 'Content-Type': 'application/json', + }; + const body = {}; + // let websocketProxyUrl = 'wss://notary.pse.dev/proxy'; + const websocketProxyUrl = 'ws://localhost:55688'; + const verifierProxyUrl = 'ws://localhost:9816/verify'; + const hostname = new URL(url).hostname; + + console.time('setup'); + + await init({ loggingLevel: 'Info' }); + + console.log('Setting up Prover for', hostname); + const prover = (await new Prover({ serverDns: hostname })) as TProver; + console.log('Setting up Prover: 1/2'); + await prover.setup(verifierProxyUrl); + console.log('Setting up Prover: done'); + + console.timeEnd('setup'); + + console.time('request'); + console.log('Sending request to proxy'); + const resp = await prover.sendRequest( + `${websocketProxyUrl}?token=${hostname}`, + { url, method, headers, body }, + ); + console.log('Response:', resp); + + console.log('Wait for transcript'); + const transcript = await prover.transcript(); + console.log('Transcript:', transcript); + + console.timeEnd('request'); + + console.time('reveal'); + const reveal = { + sent: [ + transcript.ranges.sent.info!, + transcript.ranges.sent.headers!['connection'], + transcript.ranges.sent.headers!['host'], + transcript.ranges.sent.headers!['content-type'], + transcript.ranges.sent.headers!['content-length'], + ...transcript.ranges.sent.lineBreaks, + ], + recv: [ + transcript.ranges.recv.info!, + transcript.ranges.recv.headers['server'], + transcript.ranges.recv.headers['date'], + transcript.ranges.recv.headers['content-type'], + transcript.ranges.recv.json!['name'], + transcript.ranges.recv.json!['eye_color'], + transcript.ranges.recv.json!['gender'], + ...transcript.ranges.recv.lineBreaks, + ], + }; + console.log('Start reveal:', reveal); + await prover.reveal(reveal); + console.timeEnd('reveal'); + + console.log('Ready'); + + console.log('Unredacted data:', { + sent: transcript.sent, + received: transcript.recv, + }); + + setResult('Unredacted data successfully revealed to Verifier.'); + + setProcessing(false); + }, [setResult, setProcessing]); + + return ( +
+

TLSNotary interactive prover demo

+
+ Before clicking the start button, make sure the{' '} + interactive verifier and the web socket proxy are running. + Check the README for the details. +
+ +
+ +
+
+ Proof: + {!processing && !result ? ( + not started yet + ) : !result ? ( + <> + Proving data from swapi... + + Open Developer tools to follow progress + + ) : ( + <> +
{JSON.stringify(result, null, 2)}
+ + )} +
+
+ ); +} diff --git a/interactive-demo/prover-ts/index.ejs b/interactive-demo/prover-ts/index.ejs new file mode 100644 index 0000000..7bd6dac --- /dev/null +++ b/interactive-demo/prover-ts/index.ejs @@ -0,0 +1,16 @@ + + + + + + + React/Typescrip Example + + + + +
+ + + \ No newline at end of file diff --git a/interactive-demo/prover-ts/package.json b/interactive-demo/prover-ts/package.json new file mode 100644 index 0000000..0a701e1 --- /dev/null +++ b/interactive-demo/prover-ts/package.json @@ -0,0 +1,30 @@ +{ + "name": "prover-ts", + "version": "1.0.0", + "description": "", + "main": "webpack.js", + "scripts": { + "dev": "webpack-dev-server --config webpack.js" + }, + "author": "", + "license": "ISC", + "dependencies": { + "react": "^18.2.0", + "react-dom": "^18.2.0", + "react-loader-spinner": "^6.1.6", + "tlsn-js": "file:../.." + }, + "devDependencies": { + "@types/react": "^18.0.26", + "@types/react-dom": "^18.0.10", + "babel-loader": "^9.1.3", + "copy-webpack-plugin": "^11.0.0", + "html-webpack-plugin": "^5.5.0", + "source-map-loader": "^5.0.0", + "ts-loader": "^9.4.2", + "typescript": "^4.9.4", + "webpack": "^5.75.0", + "webpack-cli": "^4.10.0", + "webpack-dev-server": "^4.11.1" + } +} \ No newline at end of file diff --git a/interactive-demo/prover-ts/tsconfig.json b/interactive-demo/prover-ts/tsconfig.json new file mode 100644 index 0000000..07378fc --- /dev/null +++ b/interactive-demo/prover-ts/tsconfig.json @@ -0,0 +1,26 @@ +{ + "compilerOptions": { + "target": "es5", + "lib": [ + "dom", + "dom.iterable", + "esnext" + ], + "allowJs": false, + "skipLibCheck": true, + "esModuleInterop": true, + "allowSyntheticDefaultImports": true, + "strict": true, + "forceConsistentCasingInFileNames": true, + "noFallthroughCasesInSwitch": true, + "module": "esnext", + "moduleResolution": "node", + "resolveJsonModule": true, + "noEmit": false, + "jsx": "react" + }, + "include": [ + "app.tsx", + "worker.ts" + ] +} \ No newline at end of file diff --git a/interactive-demo/prover-ts/webpack.js b/interactive-demo/prover-ts/webpack.js new file mode 100644 index 0000000..bf71b21 --- /dev/null +++ b/interactive-demo/prover-ts/webpack.js @@ -0,0 +1,110 @@ +var webpack = require('webpack'), + path = require('path'), + CopyWebpackPlugin = require('copy-webpack-plugin'), + HtmlWebpackPlugin = require('html-webpack-plugin'); + +const ASSET_PATH = process.env.ASSET_PATH || '/'; + +var alias = {}; + +var fileExtensions = [ + 'jpg', + 'jpeg', + 'png', + 'gif', + 'eot', + 'otf', + 'svg', + 'ttf', + 'woff', + 'woff2', +]; + +var options = { + ignoreWarnings: [ + /Circular dependency between chunks with runtime/, + /ResizeObserver loop completed with undelivered notifications/, + ], + mode: 'development', + entry: { + app: path.join(__dirname, 'app.tsx'), + }, + output: { + filename: '[name].bundle.js', + path: path.resolve(__dirname, 'build'), + clean: true, + publicPath: ASSET_PATH, + }, + module: { + rules: [ + { + test: new RegExp('.(' + fileExtensions.join('|') + ')$'), + type: 'asset/resource', + exclude: /node_modules/, + }, + { + test: /\.html$/, + loader: 'html-loader', + exclude: /node_modules/, + }, + { + test: /\.(ts|tsx)$/, + exclude: /node_modules/, + use: [ + { + loader: require.resolve('ts-loader'), + }, + ], + }, + { + test: /\.(js|jsx)$/, + use: [ + { + loader: 'source-map-loader', + }, + { + loader: require.resolve('babel-loader'), + }, + ], + exclude: /node_modules/, + }, + ], + }, + resolve: { + alias: alias, + extensions: fileExtensions + .map((extension) => '.' + extension) + .concat(['.js', '.jsx', '.ts', '.tsx', '.css']), + }, + plugins: [ + new CopyWebpackPlugin({ + patterns: [ + { + from: 'node_modules/tlsn-js/build', + to: path.join(__dirname, 'build'), + force: true, + }, + ], + }), + new HtmlWebpackPlugin({ + template: path.join(__dirname, 'index.ejs'), + filename: 'index.html', + cache: false, + }), + new webpack.ProvidePlugin({ + Buffer: ['buffer', 'Buffer'], + }), + ].filter(Boolean), + // Required by wasm-bindgen-rayon, in order to use SharedArrayBuffer on the Web + // Ref: + // - https://github.com/GoogleChromeLabs/wasm-bindgen-rayon#setting-up + // - https://web.dev/i18n/en/coop-coep/ + devServer: { + headers: { + 'Cross-Origin-Embedder-Policy': 'require-corp', + 'Cross-Origin-Opener-Policy': 'same-origin', + }, + }, +}; + +module.exports = options; diff --git a/interactive-demo/prover-ts/worker.ts b/interactive-demo/prover-ts/worker.ts new file mode 100644 index 0000000..72ced4f --- /dev/null +++ b/interactive-demo/prover-ts/worker.ts @@ -0,0 +1,7 @@ +import * as Comlink from 'comlink'; +import init, { Prover } from 'tlsn-js'; + +Comlink.expose({ + init, + Prover, +}); diff --git a/interactive-demo/verifier-rs/Cargo.toml b/interactive-demo/verifier-rs/Cargo.toml new file mode 100644 index 0000000..19ce3db --- /dev/null +++ b/interactive-demo/verifier-rs/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "interactive-networked-verifier" +version = "0.1.0" +edition = "2021" + +[dependencies] +async-trait = "0.1.67" +async-tungstenite = { version = "0.25", features = ["tokio-native-tls"] } +axum = { version = "0.7", features = ["ws"] } +axum-core = "0.4" +base64 = "0.21.0" +eyre = "0.6.12" +futures-util = "0.3.28" +http = { version = "1.1" } +http-body-util = { version = "0.1" } +hyper = { version = "1.1", features = ["client", "http1", "server"] } +hyper-util = { version = "0.1", features = ["full"] } +serde = { version = "1.0.147", features = ["derive"] } +sha1 = "0.10" +tokio = {version = "1", features = [ + "rt", + "rt-multi-thread", + "macros", + "net", + "io-std", + "fs", +]} +tokio-util = { version = "0.7", features = ["compat"] } +tower = { version = "0.4.12", features = ["make"] } +tower-service = { version = "0.3" } +tracing = "0.1.40" +tracing-subscriber = { version ="0.3.18", features = ["env-filter"] } +ws_stream_tungstenite = { version = "0.13", features = ["tokio_io"] } + +tlsn-core = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.7", package = "tlsn-core" } +tlsn-verifier = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.7", package = "tlsn-verifier" } +tlsn-common = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.7", package = "tlsn-common" } +tower-util = "0.3.1" diff --git a/interactive-demo/verifier-rs/README.md b/interactive-demo/verifier-rs/README.md new file mode 100644 index 0000000..d21efac --- /dev/null +++ b/interactive-demo/verifier-rs/README.md @@ -0,0 +1,14 @@ +# verifier-server + +An implementation of the interactive verifier server in Rust. + +## Running the server +1. Configure this server setting via the global variables defined in [main.rs](./src/main.rs) — please ensure that the hardcoded `SERVER_DOMAIN` and `VERIFICATION_SESSION_ID` have the same values on the prover side. +2. Start the server by running the following in a terminal at the root of this crate. +```bash +cargo run --release +``` + +## WebSocket APIs +### /verify +To perform verification via websocket, i.e. `ws://localhost:9816/verify` diff --git a/interactive-demo/verifier-rs/src/axum_websocket.rs b/interactive-demo/verifier-rs/src/axum_websocket.rs new file mode 100644 index 0000000..4354789 --- /dev/null +++ b/interactive-demo/verifier-rs/src/axum_websocket.rs @@ -0,0 +1,930 @@ +//! The following code is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.7.3/axum/src/extract/ws.rs +//! where we swapped out tokio_tungstenite (https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/) +//! with async_tungstenite (https://docs.rs/async-tungstenite/latest/async_tungstenite/) so that we can use +//! ws_stream_tungstenite (https://docs.rs/ws_stream_tungstenite/latest/ws_stream_tungstenite/index.html) +//! to get AsyncRead and AsyncWrite implemented for the WebSocket. Any other modification is commented with the prefix "NOTARY_MODIFICATION:" +//! +//! The code is under the following license: +//! +//! Copyright (c) 2019 Axum Contributors +//! +//! Permission is hereby granted, free of charge, to any +//! person obtaining a copy of this software and associated +//! documentation files (the "Software"), to deal in the +//! Software without restriction, including without +//! limitation the rights to use, copy, modify, merge, +//! publish, distribute, sublicense, and/or sell copies of +//! the Software, and to permit persons to whom the Software +//! is furnished to do so, subject to the following +//! conditions: +//! +//! The above copyright notice and this permission notice +//! shall be included in all copies or substantial portions +//! of the Software. +//! +//! THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +//! ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +//! TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +//! PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +//! SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +//! CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +//! OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +//! IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +//! DEALINGS IN THE SOFTWARE. +//! +//! +//! Handle WebSocket connections. +//! +//! # Example +//! +//! ``` +//! use axum::{ +//! extract::ws::{WebSocketUpgrade, WebSocket}, +//! routing::get, +//! response::{IntoResponse, Response}, +//! Router, +//! }; +//! +//! let app = Router::new().route("/ws", get(handler)); +//! +//! async fn handler(ws: WebSocketUpgrade) -> Response { +//! ws.on_upgrade(handle_socket) +//! } +//! +//! async fn handle_socket(mut socket: WebSocket) { +//! while let Some(msg) = socket.recv().await { +//! let msg = if let Ok(msg) = msg { +//! msg +//! } else { +//! // client disconnected +//! return; +//! }; +//! +//! if socket.send(msg).await.is_err() { +//! // client disconnected +//! return; +//! } +//! } +//! } +//! # let _: Router = app; +//! ``` +//! +//! # Passing data and/or state to an `on_upgrade` callback +//! +//! ``` +//! use axum::{ +//! extract::{ws::{WebSocketUpgrade, WebSocket}, State}, +//! response::Response, +//! routing::get, +//! Router, +//! }; +//! +//! #[derive(Clone)] +//! struct AppState { +//! // ... +//! } +//! +//! async fn handler(ws: WebSocketUpgrade, State(state): State) -> Response { +//! ws.on_upgrade(|socket| handle_socket(socket, state)) +//! } +//! +//! async fn handle_socket(socket: WebSocket, state: AppState) { +//! // ... +//! } +//! +//! let app = Router::new() +//! .route("/ws", get(handler)) +//! .with_state(AppState { /* ... */ }); +//! # let _: Router = app; +//! ``` +//! +//! # Read and write concurrently +//! +//! If you need to read and write concurrently from a [`WebSocket`] you can use +//! [`StreamExt::split`]: +//! +//! ```rust,no_run +//! use axum::{Error, extract::ws::{WebSocket, Message}}; +//! use futures_util::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}}; +//! +//! async fn handle_socket(mut socket: WebSocket) { +//! let (mut sender, mut receiver) = socket.split(); +//! +//! tokio::spawn(write(sender)); +//! tokio::spawn(read(receiver)); +//! } +//! +//! async fn read(receiver: SplitStream) { +//! // ... +//! } +//! +//! async fn write(sender: SplitSink) { +//! // ... +//! } +//! ``` +//! +//! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split +#![allow(unused)] + +use self::rejection::*; +use async_trait::async_trait; +use async_tungstenite::{ + tokio::TokioAdapter, + tungstenite::{ + self as ts, + protocol::{self, WebSocketConfig}, + }, + WebSocketStream, +}; +use axum::{body::Bytes, extract::FromRequestParts, response::Response, Error}; +use axum_core::body::Body; +use futures_util::{ + sink::{Sink, SinkExt}, + stream::{Stream, StreamExt}, +}; +use http::{ + header::{self, HeaderMap, HeaderName, HeaderValue}, + request::Parts, + Method, StatusCode, +}; +use hyper_util::rt::TokioIo; +use sha1::{Digest, Sha1}; +use std::{ + borrow::Cow, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tracing::error; + +/// Extractor for establishing WebSocket connections. +/// +/// Note: This extractor requires the request method to be `GET` so it should +/// always be used with [`get`](crate::routing::get). Requests with other methods will be +/// rejected. +/// +/// See the [module docs](self) for an example. +#[cfg_attr(docsrs, doc(cfg(feature = "ws")))] +pub struct WebSocketUpgrade { + config: WebSocketConfig, + /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response. + protocol: Option, + sec_websocket_key: HeaderValue, + on_upgrade: hyper::upgrade::OnUpgrade, + on_failed_upgrade: F, + sec_websocket_protocol: Option, +} + +impl std::fmt::Debug for WebSocketUpgrade { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("WebSocketUpgrade") + .field("config", &self.config) + .field("protocol", &self.protocol) + .field("sec_websocket_key", &self.sec_websocket_key) + .field("sec_websocket_protocol", &self.sec_websocket_protocol) + .finish_non_exhaustive() + } +} + +impl WebSocketUpgrade { + /// The target minimum size of the write buffer to reach before writing the data + /// to the underlying stream. + /// + /// The default value is 128 KiB. + /// + /// If set to `0` each message will be eagerly written to the underlying stream. + /// It is often more optimal to allow them to buffer a little, hence the default value. + /// + /// Note: [`flush`](SinkExt::flush) will always fully write the buffer regardless. + pub fn write_buffer_size(mut self, size: usize) -> Self { + self.config.write_buffer_size = size; + self + } + + /// The max size of the write buffer in bytes. Setting this can provide backpressure + /// in the case the write buffer is filling up due to write errors. + /// + /// The default value is unlimited. + /// + /// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size) + /// when writes to the underlying stream are failing. So the **write buffer can not + /// fill up if you are not observing write errors even if not flushing**. + /// + /// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size) + /// and probably a little more depending on error handling strategy. + pub fn max_write_buffer_size(mut self, max: usize) -> Self { + self.config.max_write_buffer_size = max; + self + } + + /// Set the maximum message size (defaults to 64 megabytes) + pub fn max_message_size(mut self, max: usize) -> Self { + self.config.max_message_size = Some(max); + self + } + + /// Set the maximum frame size (defaults to 16 megabytes) + pub fn max_frame_size(mut self, max: usize) -> Self { + self.config.max_frame_size = Some(max); + self + } + + /// Allow server to accept unmasked frames (defaults to false) + pub fn accept_unmasked_frames(mut self, accept: bool) -> Self { + self.config.accept_unmasked_frames = accept; + self + } + + /// Set the known protocols. + /// + /// If the protocol name specified by `Sec-WebSocket-Protocol` header + /// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and + /// return the protocol name. + /// + /// The protocols should be listed in decreasing order of preference: if the client offers + /// multiple protocols that the server could support, the server will pick the first one in + /// this list. + /// + /// # Examples + /// + /// ``` + /// use axum::{ + /// extract::ws::{WebSocketUpgrade, WebSocket}, + /// routing::get, + /// response::{IntoResponse, Response}, + /// Router, + /// }; + /// + /// let app = Router::new().route("/ws", get(handler)); + /// + /// async fn handler(ws: WebSocketUpgrade) -> Response { + /// ws.protocols(["graphql-ws", "graphql-transport-ws"]) + /// .on_upgrade(|socket| async { + /// // ... + /// }) + /// } + /// # let _: Router = app; + /// ``` + pub fn protocols(mut self, protocols: I) -> Self + where + I: IntoIterator, + I::Item: Into>, + { + if let Some(req_protocols) = self + .sec_websocket_protocol + .as_ref() + .and_then(|p| p.to_str().ok()) + { + self.protocol = protocols + .into_iter() + // FIXME: This will often allocate a new `String` and so is less efficient than it + // could be. But that can't be fixed without breaking changes to the public API. + .map(Into::into) + .find(|protocol| { + req_protocols + .split(',') + .any(|req_protocol| req_protocol.trim() == protocol) + }) + .map(|protocol| match protocol { + Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(), + Cow::Borrowed(s) => HeaderValue::from_static(s), + }); + } + + self + } + + /// Provide a callback to call if upgrading the connection fails. + /// + /// The connection upgrade is performed in a background task. If that fails this callback + /// will be called. + /// + /// By default any errors will be silently ignored. + /// + /// # Example + /// + /// ``` + /// use axum::{ + /// extract::{WebSocketUpgrade}, + /// response::Response, + /// }; + /// + /// async fn handler(ws: WebSocketUpgrade) -> Response { + /// ws.on_failed_upgrade(|error| { + /// report_error(error); + /// }) + /// .on_upgrade(|socket| async { /* ... */ }) + /// } + /// # + /// # fn report_error(_: axum::Error) {} + /// ``` + pub fn on_failed_upgrade(self, callback: C) -> WebSocketUpgrade + where + C: OnFailedUpgrade, + { + WebSocketUpgrade { + config: self.config, + protocol: self.protocol, + sec_websocket_key: self.sec_websocket_key, + on_upgrade: self.on_upgrade, + on_failed_upgrade: callback, + sec_websocket_protocol: self.sec_websocket_protocol, + } + } + + /// Finalize upgrading the connection and call the provided callback with + /// the stream. + #[must_use = "to set up the WebSocket connection, this response must be returned"] + pub fn on_upgrade(self, callback: C) -> Response + where + C: FnOnce(WebSocket) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + F: OnFailedUpgrade, + { + let on_upgrade = self.on_upgrade; + let config = self.config; + let on_failed_upgrade = self.on_failed_upgrade; + + let protocol = self.protocol.clone(); + + tokio::spawn(async move { + let upgraded = match on_upgrade.await { + Ok(upgraded) => upgraded, + Err(err) => { + error!("Something wrong with on_upgrade: {:?}", err); + on_failed_upgrade.call(Error::new(err)); + return; + } + }; + let upgraded = TokioIo::new(upgraded); + + let socket = WebSocketStream::from_raw_socket( + // NOTARY_MODIFICATION: Need to use TokioAdapter to wrap Upgraded which doesn't implement futures crate's AsyncRead and AsyncWrite + TokioAdapter::new(upgraded), + protocol::Role::Server, + Some(config), + ) + .await; + let socket = WebSocket { + inner: socket, + protocol, + }; + callback(socket).await; + }); + + #[allow(clippy::declare_interior_mutable_const)] + const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); + #[allow(clippy::declare_interior_mutable_const)] + const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); + + let mut builder = Response::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(header::CONNECTION, UPGRADE) + .header(header::UPGRADE, WEBSOCKET) + .header( + header::SEC_WEBSOCKET_ACCEPT, + sign(self.sec_websocket_key.as_bytes()), + ); + + if let Some(protocol) = self.protocol { + builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol); + } + + builder.body(Body::empty()).unwrap() + } +} + +/// What to do when a connection upgrade fails. +/// +/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details. +pub trait OnFailedUpgrade: Send + 'static { + /// Call the callback. + fn call(self, error: Error); +} + +impl OnFailedUpgrade for F +where + F: FnOnce(Error) + Send + 'static, +{ + fn call(self, error: Error) { + self(error) + } +} + +/// The default `OnFailedUpgrade` used by `WebSocketUpgrade`. +/// +/// It simply ignores the error. +#[non_exhaustive] +#[derive(Debug)] +pub struct DefaultOnFailedUpgrade; + +impl OnFailedUpgrade for DefaultOnFailedUpgrade { + #[inline] + fn call(self, _error: Error) {} +} + +#[async_trait] +impl FromRequestParts for WebSocketUpgrade +where + S: Send + Sync, +{ + type Rejection = WebSocketUpgradeRejection; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + if parts.method != Method::GET { + return Err(MethodNotGet.into()); + } + + if !header_contains(&parts.headers, header::CONNECTION, "upgrade") { + return Err(InvalidConnectionHeader.into()); + } + + if !header_eq(&parts.headers, header::UPGRADE, "websocket") { + return Err(InvalidUpgradeHeader.into()); + } + + if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") { + return Err(InvalidWebSocketVersionHeader.into()); + } + + let sec_websocket_key = parts + .headers + .get(header::SEC_WEBSOCKET_KEY) + .ok_or(WebSocketKeyHeaderMissing)? + .clone(); + + let on_upgrade = parts + .extensions + .remove::() + .ok_or(ConnectionNotUpgradable)?; + + let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned(); + + Ok(Self { + config: Default::default(), + protocol: None, + sec_websocket_key, + on_upgrade, + sec_websocket_protocol, + on_failed_upgrade: DefaultOnFailedUpgrade, + }) + } +} + +/// NOTARY_MODIFICATION: Made this function public to be used in service.rs +pub fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { + if let Some(header) = headers.get(&key) { + header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) + } else { + false + } +} + +fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { + let header = if let Some(header) = headers.get(&key) { + header + } else { + return false; + }; + + if let Ok(header) = std::str::from_utf8(header.as_bytes()) { + header.to_ascii_lowercase().contains(value) + } else { + false + } +} + +/// A stream of WebSocket messages. +/// +/// See [the module level documentation](self) for more details. +#[derive(Debug)] +pub struct WebSocket { + inner: WebSocketStream>>, + protocol: Option, +} + +impl WebSocket { + /// NOTARY_MODIFICATION: Consume `self` and get the inner [`async_tungstenite::WebSocketStream`]. + pub fn into_inner(self) -> WebSocketStream>> { + self.inner + } + + /// Receive another message. + /// + /// Returns `None` if the stream has closed. + pub async fn recv(&mut self) -> Option> { + self.next().await + } + + /// Send a message. + pub async fn send(&mut self, msg: Message) -> Result<(), Error> { + self.inner + .send(msg.into_tungstenite()) + .await + .map_err(Error::new) + } + + /// Gracefully close this WebSocket. + pub async fn close(mut self) -> Result<(), Error> { + self.inner.close(None).await.map_err(Error::new) + } + + /// Return the selected WebSocket subprotocol, if one has been chosen. + pub fn protocol(&self) -> Option<&HeaderValue> { + self.protocol.as_ref() + } +} + +impl Stream for WebSocket { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match futures_util::ready!(self.inner.poll_next_unpin(cx)) { + Some(Ok(msg)) => { + if let Some(msg) = Message::from_tungstenite(msg) { + return Poll::Ready(Some(Ok(msg))); + } + } + Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))), + None => return Poll::Ready(None), + } + } + } +} + +impl Sink for WebSocket { + type Error = Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new) + } + + fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { + Pin::new(&mut self.inner) + .start_send(item.into_tungstenite()) + .map_err(Error::new) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new) + } +} + +/// Status code used to indicate why an endpoint is closing the WebSocket connection. +pub type CloseCode = u16; + +/// A struct representing the close command. +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct CloseFrame<'t> { + /// The reason as a code. + pub code: CloseCode, + /// The reason as text string. + pub reason: Cow<'t, str>, +} + +/// A WebSocket message. +// +// This code comes from https://github.com/snapview/tungstenite-rs/blob/master/src/protocol/message.rs and is under following license: +// Copyright (c) 2017 Alexey Galakhov +// Copyright (c) 2016 Jason Housley +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Message { + /// A text WebSocket message + Text(String), + /// A binary WebSocket message + Binary(Vec), + /// A ping message with the specified payload + /// + /// The payload here must have a length less than 125 bytes. + /// + /// Ping messages will be automatically responded to by the server, so you do not have to worry + /// about dealing with them yourself. + Ping(Vec), + /// A pong message with the specified payload + /// + /// The payload here must have a length less than 125 bytes. + /// + /// Pong messages will be automatically sent to the client if a ping message is received, so + /// you do not have to worry about constructing them yourself unless you want to implement a + /// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3). + Pong(Vec), + /// A close message with the optional close frame. + Close(Option>), +} + +impl Message { + fn into_tungstenite(self) -> ts::Message { + match self { + Self::Text(text) => ts::Message::Text(text), + Self::Binary(binary) => ts::Message::Binary(binary), + Self::Ping(ping) => ts::Message::Ping(ping), + Self::Pong(pong) => ts::Message::Pong(pong), + Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame { + code: ts::protocol::frame::coding::CloseCode::from(close.code), + reason: close.reason, + })), + Self::Close(None) => ts::Message::Close(None), + } + } + + fn from_tungstenite(message: ts::Message) -> Option { + match message { + ts::Message::Text(text) => Some(Self::Text(text)), + ts::Message::Binary(binary) => Some(Self::Binary(binary)), + ts::Message::Ping(ping) => Some(Self::Ping(ping)), + ts::Message::Pong(pong) => Some(Self::Pong(pong)), + ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame { + code: close.code.into(), + reason: close.reason, + }))), + ts::Message::Close(None) => Some(Self::Close(None)), + // we can ignore `Frame` frames as recommended by the tungstenite maintainers + // https://github.com/snapview/tungstenite-rs/issues/268 + ts::Message::Frame(_) => None, + } + } + + /// Consume the WebSocket and return it as binary data. + pub fn into_data(self) -> Vec { + match self { + Self::Text(string) => string.into_bytes(), + Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data, + Self::Close(None) => Vec::new(), + Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(), + } + } + + /// Attempt to consume the WebSocket message and convert it to a String. + pub fn into_text(self) -> Result { + match self { + Self::Text(string) => Ok(string), + Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data) + .map_err(|err| err.utf8_error()) + .map_err(Error::new)?), + Self::Close(None) => Ok(String::new()), + Self::Close(Some(frame)) => Ok(frame.reason.into_owned()), + } + } + + /// Attempt to get a &str from the WebSocket message, + /// this will try to convert binary data to utf8. + pub fn to_text(&self) -> Result<&str, Error> { + match *self { + Self::Text(ref string) => Ok(string), + Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => { + Ok(std::str::from_utf8(data).map_err(Error::new)?) + } + Self::Close(None) => Ok(""), + Self::Close(Some(ref frame)) => Ok(&frame.reason), + } + } +} + +impl From for Message { + fn from(string: String) -> Self { + Message::Text(string) + } +} + +impl<'s> From<&'s str> for Message { + fn from(string: &'s str) -> Self { + Message::Text(string.into()) + } +} + +impl<'b> From<&'b [u8]> for Message { + fn from(data: &'b [u8]) -> Self { + Message::Binary(data.into()) + } +} + +impl From> for Message { + fn from(data: Vec) -> Self { + Message::Binary(data) + } +} + +impl From for Vec { + fn from(msg: Message) -> Self { + msg.into_data() + } +} + +fn sign(key: &[u8]) -> HeaderValue { + use base64::engine::Engine as _; + + let mut sha1 = Sha1::default(); + sha1.update(key); + sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]); + let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize())); + HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value") +} + +pub mod rejection { + //! WebSocket specific rejections. + + use axum_core::{ + __composite_rejection as composite_rejection, __define_rejection as define_rejection, + }; + + define_rejection! { + #[status = METHOD_NOT_ALLOWED] + #[body = "Request method must be `GET`"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + pub struct MethodNotGet; + } + + define_rejection! { + #[status = BAD_REQUEST] + #[body = "Connection header did not include 'upgrade'"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + pub struct InvalidConnectionHeader; + } + + define_rejection! { + #[status = BAD_REQUEST] + #[body = "`Upgrade` header did not include 'websocket'"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + pub struct InvalidUpgradeHeader; + } + + define_rejection! { + #[status = BAD_REQUEST] + #[body = "`Sec-WebSocket-Version` header did not include '13'"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + pub struct InvalidWebSocketVersionHeader; + } + + define_rejection! { + #[status = BAD_REQUEST] + #[body = "`Sec-WebSocket-Key` header missing"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + pub struct WebSocketKeyHeaderMissing; + } + + define_rejection! { + #[status = UPGRADE_REQUIRED] + #[body = "WebSocket request couldn't be upgraded since no upgrade state was present"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + /// + /// This rejection is returned if the connection cannot be upgraded for example if the + /// request is HTTP/1.0. + /// + /// See [MDN] for more details about connection upgrades. + /// + /// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Upgrade + pub struct ConnectionNotUpgradable; + } + + composite_rejection! { + /// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade). + /// + /// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade) + /// extractor can fail. + pub enum WebSocketUpgradeRejection { + MethodNotGet, + InvalidConnectionHeader, + InvalidUpgradeHeader, + InvalidWebSocketVersionHeader, + WebSocketKeyHeaderMissing, + ConnectionNotUpgradable, + } + } +} + +pub mod close_code { + //! Constants for [`CloseCode`]s. + //! + //! [`CloseCode`]: super::CloseCode + + /// Indicates a normal closure, meaning that the purpose for which the connection was + /// established has been fulfilled. + pub const NORMAL: u16 = 1000; + + /// Indicates that an endpoint is "going away", such as a server going down or a browser having + /// navigated away from a page. + pub const AWAY: u16 = 1001; + + /// Indicates that an endpoint is terminating the connection due to a protocol error. + pub const PROTOCOL: u16 = 1002; + + /// Indicates that an endpoint is terminating the connection because it has received a type of + /// data it cannot accept (e.g., an endpoint that understands only text data MAY send this if + /// it receives a binary message). + pub const UNSUPPORTED: u16 = 1003; + + /// Indicates that no status code was included in a closing frame. + pub const STATUS: u16 = 1005; + + /// Indicates an abnormal closure. + pub const ABNORMAL: u16 = 1006; + + /// Indicates that an endpoint is terminating the connection because it has received data + /// within a message that was not consistent with the type of the message (e.g., non-UTF-8 + /// RFC3629 data within a text message). + pub const INVALID: u16 = 1007; + + /// Indicates that an endpoint is terminating the connection because it has received a message + /// that violates its policy. This is a generic status code that can be returned when there is + /// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a need to + /// hide specific details about the policy. + pub const POLICY: u16 = 1008; + + /// Indicates that an endpoint is terminating the connection because it has received a message + /// that is too big for it to process. + pub const SIZE: u16 = 1009; + + /// Indicates that an endpoint (client) is terminating the connection because it has expected + /// the server to negotiate one or more extension, but the server didn't return them in the + /// response message of the WebSocket handshake. The list of extensions that are needed should + /// be given as the reason for closing. Note that this status code is not used by the server, + /// because it can fail the WebSocket handshake instead. + pub const EXTENSION: u16 = 1010; + + /// Indicates that a server is terminating the connection because it encountered an unexpected + /// condition that prevented it from fulfilling the request. + pub const ERROR: u16 = 1011; + + /// Indicates that the server is restarting. + pub const RESTART: u16 = 1012; + + /// Indicates that the server is overloaded and the client should either connect to a different + /// IP (when multiple targets exist), or reconnect to the same IP when a user has performed an + /// action. + pub const AGAIN: u16 = 1013; +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::{body::Body, routing::get, Router}; + use http::{Request, Version}; + // NOTARY_MODIFICATION: use tower_util instead of tower to make clippy happy + use tower_util::ServiceExt; + + #[tokio::test] + async fn rejects_http_1_0_requests() { + let svc = get(|ws: Result| { + let rejection = ws.unwrap_err(); + assert!(matches!( + rejection, + WebSocketUpgradeRejection::ConnectionNotUpgradable(_) + )); + std::future::ready(()) + }); + + let req = Request::builder() + .version(Version::HTTP_10) + .method(Method::GET) + .header("upgrade", "websocket") + .header("connection", "Upgrade") + .header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==") + .header("sec-websocket-version", "13") + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[allow(dead_code)] + fn default_on_failed_upgrade() { + async fn handler(ws: WebSocketUpgrade) -> Response { + ws.on_upgrade(|_| async {}) + } + let _: Router = Router::new().route("/", get(handler)); + } + + #[allow(dead_code)] + fn on_failed_upgrade() { + async fn handler(ws: WebSocketUpgrade) -> Response { + ws.on_failed_upgrade(|_error: Error| println!("oops!")) + .on_upgrade(|_| async {}) + } + let _: Router = Router::new().route("/", get(handler)); + } +} diff --git a/interactive-demo/verifier-rs/src/lib.rs b/interactive-demo/verifier-rs/src/lib.rs new file mode 100644 index 0000000..04acb1e --- /dev/null +++ b/interactive-demo/verifier-rs/src/lib.rs @@ -0,0 +1,176 @@ +use axum::{ + extract::{Request, State}, + response::IntoResponse, + routing::get, + Router, +}; +use axum_websocket::{WebSocket, WebSocketUpgrade}; +use eyre::eyre; +use hyper::{body::Incoming, server::conn::http1}; +use hyper_util::rt::TokioIo; +use std::{ + net::{IpAddr, SocketAddr}, + sync::Arc, +}; +use tlsn_common::config::ProtocolConfigValidator; +use tlsn_verifier::{SessionInfo, Verifier, VerifierConfig}; + +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpListener, +}; +use tokio_util::compat::TokioAsyncReadCompatExt; +use tower_service::Service; +use tracing::{debug, error, info}; +use ws_stream_tungstenite::WsStream; + +mod axum_websocket; + +// Maximum number of bytes that can be sent from prover to server +const MAX_SENT_DATA: usize = 1 << 12; +// Maximum number of bytes that can be received by prover from server +const MAX_RECV_DATA: usize = 1 << 14; + +/// Global data that needs to be shared with the axum handlers +#[derive(Clone, Debug)] +struct VerifierGlobals { + pub server_domain: String, +} + +pub async fn run_server( + verifier_host: &str, + verifier_port: u16, + server_domain: &str, +) -> Result<(), eyre::ErrReport> { + let verifier_address = SocketAddr::new( + IpAddr::V4(verifier_host.parse().map_err(|err| { + eyre!("Failed to parse verifer host address from server config: {err}") + })?), + verifier_port, + ); + let listener = TcpListener::bind(verifier_address) + .await + .map_err(|err| eyre!("Failed to bind server address to tcp listener: {err}"))?; + + info!("Listening for TCP traffic at {}", verifier_address); + + let protocol = Arc::new(http1::Builder::new()); + let router = Router::new() + .route("/verify", get(ws_handler)) + .with_state(VerifierGlobals { + server_domain: server_domain.to_string(), + }); + + loop { + let stream = match listener.accept().await { + Ok((stream, _)) => stream, + Err(err) => { + error!("Failed to connect to prover: {err}"); + continue; + } + }; + debug!("Received a prover's TCP connection"); + + let tower_service = router.clone(); + let protocol = protocol.clone(); + + tokio::spawn(async move { + info!("Accepted prover's TCP connection",); + // Reference: https://github.com/tokio-rs/axum/blob/5201798d4e4d4759c208ef83e30ce85820c07baa/examples/low-level-rustls/src/main.rs#L67-L80 + let io = TokioIo::new(stream); + let hyper_service = hyper::service::service_fn(move |request: Request| { + tower_service.clone().call(request) + }); + // Serve different requests using the same hyper protocol and axum router + let _ = protocol + .serve_connection(io, hyper_service) + // use with_upgrades to upgrade connection to websocket for websocket clients + // and to extract tcp connection for tcp clients + .with_upgrades() + .await; + }); + } +} + +async fn ws_handler( + ws: WebSocketUpgrade, + State(verifier_globals): State, +) -> impl IntoResponse { + info!("Received websocket request"); + ws.on_upgrade(|socket| handle_socket(socket, verifier_globals)) +} + +async fn handle_socket(socket: WebSocket, verifier_globals: VerifierGlobals) { + debug!("Upgraded to websocket connection"); + let stream = WsStream::new(socket.into_inner()); + + match verifier(stream, &verifier_globals.server_domain).await { + Ok((sent, received, _session_info)) => { + info!("Successfully verified {}", &verifier_globals.server_domain); + info!("Verified sent data:\n{}", sent,); + println!("Verified received data:\n{}", received,); + } + Err(err) => { + error!("Failed verification using websocket: {err}"); + } + } +} + +async fn verifier( + socket: T, + server_domain: &str, +) -> Result<(String, String, SessionInfo), eyre::ErrReport> { + debug!("Starting verification..."); + + // Setup Verifier. + let config_validator = ProtocolConfigValidator::builder() + .max_sent_data(MAX_SENT_DATA) + .max_recv_data(MAX_RECV_DATA) + .build() + .unwrap(); + + let verifier_config = VerifierConfig::builder() + .protocol_config_validator(config_validator) + .build() + .unwrap(); + let verifier = Verifier::new(verifier_config); + + // Verify MPC-TLS and wait for (redacted) data. + debug!("Starting MPC-TLS verification..."); + // Verify MPC-TLS and wait for (redacted) data. + let (mut partial_transcript, session_info) = verifier.verify(socket.compat()).await.unwrap(); + partial_transcript.set_unauthed(0); + + // Check sent data: check host. + debug!("Starting sent data verification..."); + let sent = partial_transcript.sent_unsafe().to_vec(); + let sent_data = String::from_utf8(sent.clone()).expect("Verifier expected sent data"); + sent_data + .find(server_domain) + .ok_or_else(|| eyre!("Verification failed: Expected host {}", server_domain))?; + + // Check received data: check json and version number. + debug!("Starting received data verification..."); + let received = partial_transcript.received_unsafe().to_vec(); + let response = String::from_utf8(received.clone()).expect("Verifier expected received data"); + debug!("Received data: {:?}", response); + response + .find("eye_color") + .ok_or_else(|| eyre!("Verification failed: missing eye_color in received data"))?; + // Check Session info: server name. + if session_info.server_name.as_str() != server_domain { + return Err(eyre!("Verification failed: server name mismatches")); + } + + let sent_string = bytes_to_redacted_string(&sent)?; + let received_string = bytes_to_redacted_string(&received)?; + + Ok((sent_string, received_string, session_info)) +} + +/// Render redacted bytes as `🙈`. +fn bytes_to_redacted_string(bytes: &[u8]) -> Result { + Ok(String::from_utf8(bytes.to_vec()) + .map_err(|err| eyre!("Failed to parse bytes to redacted string: {err}"))? + .replace('\0', "🙈")) +} diff --git a/interactive-demo/verifier-rs/src/main.rs b/interactive-demo/verifier-rs/src/main.rs new file mode 100644 index 0000000..b9455f2 --- /dev/null +++ b/interactive-demo/verifier-rs/src/main.rs @@ -0,0 +1,22 @@ +use interactive_networked_verifier::run_server; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +const TRACING_FILTER: &str = "INFO"; + +const VERIFIER_HOST: &str = "0.0.0.0"; +const VERIFIER_PORT: u16 = 9816; + +/// Make sure the following domain is the same in SERVER_URL on the prover side +const SERVER_DOMAIN: &str = "swapi.dev"; + +#[tokio::main] +async fn main() -> Result<(), eyre::ErrReport> { + tracing_subscriber::registry() + .with(EnvFilter::try_from_default_env().unwrap_or_else(|_| TRACING_FILTER.into())) + .with(tracing_subscriber::fmt::layer()) + .init(); + + run_server(VERIFIER_HOST, VERIFIER_PORT, SERVER_DOMAIN).await?; + + Ok(()) +} diff --git a/package.json b/package.json index c7dc242..1f2e65f 100644 --- a/package.json +++ b/package.json @@ -8,7 +8,6 @@ "files": [ "build/", "src/", - "wasm/pkg/*", "readme.md" ], "scripts": { diff --git a/src/lib.ts b/src/lib.ts index eb98dda..33d5e1a 100644 --- a/src/lib.ts +++ b/src/lib.ts @@ -56,9 +56,9 @@ export default async function init(config?: { }); // 6422528 ~= 6.12 mb - debug('res.memory=', res.memory); - debug('res.memory.buffer.length=', res.memory.buffer.byteLength); - debug('DEBUG', 'initialize thread pool'); + debug('res.memory', res.memory); + debug('res.memory.buffer.length', res.memory.buffer.byteLength); + debug('initialize thread pool'); await initThreadPool(hardwareConcurrency); debug('initialized thread pool'); @@ -273,6 +273,7 @@ export class Prover { async reveal(reveal: Reveal) { return this.#prover.reveal(reveal); } + } export class Verifier {