Interactive Prover in Typescript (Verifier in Rust) (#74)

feat: interactive verifier demo

* Verifier in Rust
* Prover in both Rust and TypeScript/React

---------
Co-authored-by: Pete Thomas <pete@xminusone.net>
Co-authored-by: yuroitaki <25913766+yuroitaki@users.noreply.github.com>
This commit is contained in:
Hendrik Eeckhaut
2024-11-05 22:41:30 +07:00
committed by GitHub
parent 18a30d32c1
commit 1ded4136bf
20 changed files with 1771 additions and 4 deletions

6
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,6 @@
{
"rust-analyzer.linkedProjects": [
"interactive-demo/verifier-rs/Cargo.toml",
"interactive-demo/prover-rs/Cargo.toml"
],
}

2
interactive-demo/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
**/target/
**/Cargo.lock

View File

@@ -0,0 +1,38 @@
# Test Rust Prover
1. Start the verifier:
```bash
cd verifier-rs; cargo run --release
```
2. Run the prover:
```bash
cd prover-rs; cargo run --release
```
# Test Browser Prover
1. Start the verifier:
```bash
cd verifier-rs; cargo run --release
```
2. Since a web browser doesn't have the ability to make TCP connection, we need to use a websocket proxy server to access <swapi.dev>.
```bash
cargo install wstcp
wstcp --bind-addr 127.0.0.1:55688 swapi.dev:443
```
3. Run the prover
1. Build tlsn-js
```bash
cd ..
npm i
npm run build
npm link
```
2. Build demo prover-ts
```bash
cd prover-ts
npm i
npm link
npm run dev
```
3. Open <http://localhost:8080/> and click **Start Prover**

View File

@@ -0,0 +1,30 @@
[package]
name = "interactive-networked-prover"
version = "0.1.0"
edition = "2021"
[dependencies]
async-tungstenite = { version = "0.25", features = ["tokio-runtime"] }
futures = "0.3"
http = "1.1"
http-body-util = "0.1"
hyper = {version = "1.1", features = ["client", "http1"]}
hyper-util = {version = "0.1", features = ["full"]}
regex = "1.10.3"
tokio = {version = "1", features = [
"rt",
"rt-multi-thread",
"macros",
"net",
"io-std",
"fs",
]}
tokio-util = { version = "0.7", features = ["compat"] }
tracing = "0.1.40"
tracing-subscriber = { version ="0.3.18", features = ["env-filter"] }
uuid = { version = "1.4.1", features = ["v4", "fast-rng"] }
ws_stream_tungstenite = { version = "0.13", features = ["tokio_io"] }
tlsn-core = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.7", package = "tlsn-core" }
tlsn-prover = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.7", package = "tlsn-prover" }
tlsn-common = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.7", package = "tlsn-common" }

View File

@@ -0,0 +1,9 @@
## Interactive Prover
An implementation of the interactive prover in Rust.
## Running the prover
1. Configure this prover setting via the global variables defined in [main.rs](./src/main.rs) — please ensure that the hardcoded `SERVER_URL` and `VERIFICATION_SESSION_ID` have the same values on the verifier side.
2. Start the prover by running the following in a terminal at the root of this crate.
```bash
cargo run --release
```

View File

@@ -0,0 +1,172 @@
use async_tungstenite::{tokio::connect_async_with_config, tungstenite::protocol::WebSocketConfig};
use http_body_util::Empty;
use hyper::{body::Bytes, Request, StatusCode, Uri};
use hyper_util::rt::TokioIo;
use regex::Regex;
use tlsn_common::config::ProtocolConfig;
use tlsn_core::transcript::Idx;
use tlsn_prover::{state::Prove, Prover, ProverConfig};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tracing::{debug, info};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use ws_stream_tungstenite::WsStream;
const TRACING_FILTER: &str = "INFO";
const VERIFIER_HOST: &str = "localhost";
const VERIFIER_PORT: u16 = 9816;
// Maximum number of bytes that can be sent from prover to server
const MAX_SENT_DATA: usize = 1 << 12;
// Maximum number of bytes that can be received by prover from server
const MAX_RECV_DATA: usize = 1 << 14;
const SECRET: &str = "TLSNotary's private key 🤡";
/// Make sure the following url's domain is the same as SERVER_DOMAIN on the verifier side
const SERVER_URL: &str = "https://swapi.dev/api/people/1";
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| TRACING_FILTER.into()))
.with(tracing_subscriber::fmt::layer())
.init();
run_prover(VERIFIER_HOST, VERIFIER_PORT, SERVER_URL).await;
}
async fn run_prover(verifier_host: &str, verifier_port: u16, server_uri: &str) {
info!("Sending websocket request...");
let request = http::Request::builder()
.uri(format!("ws://{}:{}/verify", verifier_host, verifier_port,))
.header("Host", verifier_host)
.header("Sec-WebSocket-Key", uuid::Uuid::new_v4().to_string())
.header("Sec-WebSocket-Version", "13")
.header("Connection", "Upgrade")
.header("Upgrade", "Websocket")
.body(())
.unwrap();
let (verifier_ws_stream, _) =
connect_async_with_config(request, Some(WebSocketConfig::default()))
.await
.unwrap();
info!("Websocket connection established!");
let verifier_ws_socket = WsStream::new(verifier_ws_stream);
prover(verifier_ws_socket, server_uri).await;
info!("Proving is successful!");
}
async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(verifier_socket: T, uri: &str) {
debug!("Starting proving...");
let uri = uri.parse::<Uri>().unwrap();
assert_eq!(uri.scheme().unwrap().as_str(), "https");
let server_domain = uri.authority().unwrap().host();
let server_port = uri.port_u16().unwrap_or(443);
// Create prover and connect to verifier.
//
// Perform the setup phase with the verifier.
let prover = Prover::new(
ProverConfig::builder()
.server_name(server_domain)
.protocol_config(
ProtocolConfig::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()
.unwrap(),
)
.build()
.unwrap(),
)
.setup(verifier_socket.compat())
.await
.unwrap();
// Connect to TLS Server.
let tls_client_socket = tokio::net::TcpStream::connect((server_domain, server_port))
.await
.unwrap();
// Pass server connection into the prover.
let (mpc_tls_connection, prover_fut) =
prover.connect(tls_client_socket.compat()).await.unwrap();
// Wrap the connection in a TokioIo compatibility layer to use it with hyper.
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat());
// Spawn the Prover to run in the background.
let prover_task = tokio::spawn(prover_fut);
// MPC-TLS Handshake.
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(mpc_tls_connection)
.await
.unwrap();
tokio::spawn(connection);
// MPC-TLS: Send Request and wait for Response.
info!("Send Request and wait for Response");
let request = Request::builder()
.uri(uri.clone())
.header("Host", server_domain)
.header("Connection", "close")
.header("Secret", SECRET)
.method("GET")
.body(Empty::<Bytes>::new())
.unwrap();
let response = request_sender.send_request(request).await.unwrap();
debug!("TLS response: {:?}", response);
assert!(response.status() == StatusCode::OK);
// Create proof for the Verifier.
let mut prover = prover_task.await.unwrap().unwrap().start_prove();
let idx_sent = redact_and_reveal_sent_data(&mut prover);
let idx_recv = redact_and_reveal_received_data(&mut prover);
// Reveal parts of the transcript
prover.prove_transcript(idx_sent, idx_recv).await.unwrap();
// Finalize.
prover.finalize().await.unwrap()
}
/// Redacts and reveals received data to the verifier.
fn redact_and_reveal_received_data(prover: &mut Prover<Prove>) -> Idx {
let recv_transcript = prover.transcript().received();
let recv_transcript_len = recv_transcript.len();
// Get the homeworld from the received data.
let received_string = String::from_utf8(recv_transcript.to_vec()).unwrap();
debug!("Received data: {}", received_string);
let re = Regex::new(r#""homeworld"\s?:\s?"(.*?)""#).unwrap();
let homeworld_match = re.captures(&received_string).unwrap().get(1).unwrap();
// Reveal everything except for the homeworld.
let start = homeworld_match.start();
let end = homeworld_match.end();
Idx::new([0..start, end..recv_transcript_len])
}
/// Redacts and reveals sent data to the verifier.
fn redact_and_reveal_sent_data(prover: &mut Prover<Prove>) -> Idx {
let sent_transcript = prover.transcript().sent();
let sent_transcript_len = sent_transcript.len();
let sent_string: String = String::from_utf8(sent_transcript.to_vec()).unwrap();
let secret_start = sent_string.find(SECRET).unwrap();
debug!("Send data: {}", sent_string);
// Reveal everything except for the SECRET.
Idx::new([
0..secret_start,
secret_start + SECRET.len()..sent_transcript_len,
])
}

1
interactive-demo/prover-ts/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
package-lock.json

View File

@@ -0,0 +1,140 @@
import React, { ReactElement, useCallback, useState } from 'react';
import { createRoot } from 'react-dom/client';
import * as Comlink from 'comlink';
import { Watch } from 'react-loader-spinner';
import { Prover as TProver } from 'tlsn-js';
import { type Method } from 'tlsn-wasm';
const { init, Prover }: any = Comlink.wrap(
new Worker(new URL('./worker.ts', import.meta.url)),
);
const container = document.getElementById('root');
const root = createRoot(container!);
root.render(<App />);
function App(): ReactElement {
const [processing, setProcessing] = useState(false);
const [result, setResult] = useState<string | null>(null);
const onClick = useCallback(async () => {
setProcessing(true);
const url = 'https://swapi.dev/api/people/1';
const method: Method = 'GET';
const headers = {
secret: "TLSNotary's private key",
'Content-Type': 'application/json',
};
const body = {};
// let websocketProxyUrl = 'wss://notary.pse.dev/proxy';
const websocketProxyUrl = 'ws://localhost:55688';
const verifierProxyUrl = 'ws://localhost:9816/verify';
const hostname = new URL(url).hostname;
console.time('setup');
await init({ loggingLevel: 'Info' });
console.log('Setting up Prover for', hostname);
const prover = (await new Prover({ serverDns: hostname })) as TProver;
console.log('Setting up Prover: 1/2');
await prover.setup(verifierProxyUrl);
console.log('Setting up Prover: done');
console.timeEnd('setup');
console.time('request');
console.log('Sending request to proxy');
const resp = await prover.sendRequest(
`${websocketProxyUrl}?token=${hostname}`,
{ url, method, headers, body },
);
console.log('Response:', resp);
console.log('Wait for transcript');
const transcript = await prover.transcript();
console.log('Transcript:', transcript);
console.timeEnd('request');
console.time('reveal');
const reveal = {
sent: [
transcript.ranges.sent.info!,
transcript.ranges.sent.headers!['connection'],
transcript.ranges.sent.headers!['host'],
transcript.ranges.sent.headers!['content-type'],
transcript.ranges.sent.headers!['content-length'],
...transcript.ranges.sent.lineBreaks,
],
recv: [
transcript.ranges.recv.info!,
transcript.ranges.recv.headers['server'],
transcript.ranges.recv.headers['date'],
transcript.ranges.recv.headers['content-type'],
transcript.ranges.recv.json!['name'],
transcript.ranges.recv.json!['eye_color'],
transcript.ranges.recv.json!['gender'],
...transcript.ranges.recv.lineBreaks,
],
};
console.log('Start reveal:', reveal);
await prover.reveal(reveal);
console.timeEnd('reveal');
console.log('Ready');
console.log('Unredacted data:', {
sent: transcript.sent,
received: transcript.recv,
});
setResult('Unredacted data successfully revealed to Verifier.');
setProcessing(false);
}, [setResult, setProcessing]);
return (
<div>
<h1>TLSNotary interactive prover demo</h1>
<div>
Before clicking the start button, make sure the{' '}
<i>interactive verifier</i> and the <i>web socket proxy</i> are running.
Check the README for the details.
</div>
<br />
<button onClick={!processing ? onClick : undefined} disabled={processing}>
Start Prover
</button>
<br />
<div>
<b>Proof: </b>
{!processing && !result ? (
<i>not started yet</i>
) : !result ? (
<>
Proving data from swapi...
<Watch
visible={true}
height="40"
width="40"
radius="48"
color="#000000"
ariaLabel="watch-loading"
wrapperStyle={{}}
wrapperClass=""
/>
Open <i>Developer tools</i> to follow progress
</>
) : (
<>
<pre>{JSON.stringify(result, null, 2)}</pre>
</>
)}
</div>
</div>
);
}

View File

@@ -0,0 +1,16 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>React/Typescrip Example</title>
</head>
<body>
<script>
</script>
<div id="root"></div>
</body>
</html>

View File

@@ -0,0 +1,30 @@
{
"name": "prover-ts",
"version": "1.0.0",
"description": "",
"main": "webpack.js",
"scripts": {
"dev": "webpack-dev-server --config webpack.js"
},
"author": "",
"license": "ISC",
"dependencies": {
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-loader-spinner": "^6.1.6",
"tlsn-js": "file:../.."
},
"devDependencies": {
"@types/react": "^18.0.26",
"@types/react-dom": "^18.0.10",
"babel-loader": "^9.1.3",
"copy-webpack-plugin": "^11.0.0",
"html-webpack-plugin": "^5.5.0",
"source-map-loader": "^5.0.0",
"ts-loader": "^9.4.2",
"typescript": "^4.9.4",
"webpack": "^5.75.0",
"webpack-cli": "^4.10.0",
"webpack-dev-server": "^4.11.1"
}
}

View File

@@ -0,0 +1,26 @@
{
"compilerOptions": {
"target": "es5",
"lib": [
"dom",
"dom.iterable",
"esnext"
],
"allowJs": false,
"skipLibCheck": true,
"esModuleInterop": true,
"allowSyntheticDefaultImports": true,
"strict": true,
"forceConsistentCasingInFileNames": true,
"noFallthroughCasesInSwitch": true,
"module": "esnext",
"moduleResolution": "node",
"resolveJsonModule": true,
"noEmit": false,
"jsx": "react"
},
"include": [
"app.tsx",
"worker.ts"
]
}

View File

@@ -0,0 +1,110 @@
var webpack = require('webpack'),
path = require('path'),
CopyWebpackPlugin = require('copy-webpack-plugin'),
HtmlWebpackPlugin = require('html-webpack-plugin');
const ASSET_PATH = process.env.ASSET_PATH || '/';
var alias = {};
var fileExtensions = [
'jpg',
'jpeg',
'png',
'gif',
'eot',
'otf',
'svg',
'ttf',
'woff',
'woff2',
];
var options = {
ignoreWarnings: [
/Circular dependency between chunks with runtime/,
/ResizeObserver loop completed with undelivered notifications/,
],
mode: 'development',
entry: {
app: path.join(__dirname, 'app.tsx'),
},
output: {
filename: '[name].bundle.js',
path: path.resolve(__dirname, 'build'),
clean: true,
publicPath: ASSET_PATH,
},
module: {
rules: [
{
test: new RegExp('.(' + fileExtensions.join('|') + ')$'),
type: 'asset/resource',
exclude: /node_modules/,
},
{
test: /\.html$/,
loader: 'html-loader',
exclude: /node_modules/,
},
{
test: /\.(ts|tsx)$/,
exclude: /node_modules/,
use: [
{
loader: require.resolve('ts-loader'),
},
],
},
{
test: /\.(js|jsx)$/,
use: [
{
loader: 'source-map-loader',
},
{
loader: require.resolve('babel-loader'),
},
],
exclude: /node_modules/,
},
],
},
resolve: {
alias: alias,
extensions: fileExtensions
.map((extension) => '.' + extension)
.concat(['.js', '.jsx', '.ts', '.tsx', '.css']),
},
plugins: [
new CopyWebpackPlugin({
patterns: [
{
from: 'node_modules/tlsn-js/build',
to: path.join(__dirname, 'build'),
force: true,
},
],
}),
new HtmlWebpackPlugin({
template: path.join(__dirname, 'index.ejs'),
filename: 'index.html',
cache: false,
}),
new webpack.ProvidePlugin({
Buffer: ['buffer', 'Buffer'],
}),
].filter(Boolean),
// Required by wasm-bindgen-rayon, in order to use SharedArrayBuffer on the Web
// Ref:
// - https://github.com/GoogleChromeLabs/wasm-bindgen-rayon#setting-up
// - https://web.dev/i18n/en/coop-coep/
devServer: {
headers: {
'Cross-Origin-Embedder-Policy': 'require-corp',
'Cross-Origin-Opener-Policy': 'same-origin',
},
},
};
module.exports = options;

View File

@@ -0,0 +1,7 @@
import * as Comlink from 'comlink';
import init, { Prover } from 'tlsn-js';
Comlink.expose({
init,
Prover,
});

View File

@@ -0,0 +1,38 @@
[package]
name = "interactive-networked-verifier"
version = "0.1.0"
edition = "2021"
[dependencies]
async-trait = "0.1.67"
async-tungstenite = { version = "0.25", features = ["tokio-native-tls"] }
axum = { version = "0.7", features = ["ws"] }
axum-core = "0.4"
base64 = "0.21.0"
eyre = "0.6.12"
futures-util = "0.3.28"
http = { version = "1.1" }
http-body-util = { version = "0.1" }
hyper = { version = "1.1", features = ["client", "http1", "server"] }
hyper-util = { version = "0.1", features = ["full"] }
serde = { version = "1.0.147", features = ["derive"] }
sha1 = "0.10"
tokio = {version = "1", features = [
"rt",
"rt-multi-thread",
"macros",
"net",
"io-std",
"fs",
]}
tokio-util = { version = "0.7", features = ["compat"] }
tower = { version = "0.4.12", features = ["make"] }
tower-service = { version = "0.3" }
tracing = "0.1.40"
tracing-subscriber = { version ="0.3.18", features = ["env-filter"] }
ws_stream_tungstenite = { version = "0.13", features = ["tokio_io"] }
tlsn-core = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.7", package = "tlsn-core" }
tlsn-verifier = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.7", package = "tlsn-verifier" }
tlsn-common = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.7", package = "tlsn-common" }
tower-util = "0.3.1"

View File

@@ -0,0 +1,14 @@
# verifier-server
An implementation of the interactive verifier server in Rust.
## Running the server
1. Configure this server setting via the global variables defined in [main.rs](./src/main.rs) — please ensure that the hardcoded `SERVER_DOMAIN` and `VERIFICATION_SESSION_ID` have the same values on the prover side.
2. Start the server by running the following in a terminal at the root of this crate.
```bash
cargo run --release
```
## WebSocket APIs
### /verify
To perform verification via websocket, i.e. `ws://localhost:9816/verify`

View File

@@ -0,0 +1,930 @@
//! The following code is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.7.3/axum/src/extract/ws.rs
//! where we swapped out tokio_tungstenite (https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/)
//! with async_tungstenite (https://docs.rs/async-tungstenite/latest/async_tungstenite/) so that we can use
//! ws_stream_tungstenite (https://docs.rs/ws_stream_tungstenite/latest/ws_stream_tungstenite/index.html)
//! to get AsyncRead and AsyncWrite implemented for the WebSocket. Any other modification is commented with the prefix "NOTARY_MODIFICATION:"
//!
//! The code is under the following license:
//!
//! Copyright (c) 2019 Axum Contributors
//!
//! Permission is hereby granted, free of charge, to any
//! person obtaining a copy of this software and associated
//! documentation files (the "Software"), to deal in the
//! Software without restriction, including without
//! limitation the rights to use, copy, modify, merge,
//! publish, distribute, sublicense, and/or sell copies of
//! the Software, and to permit persons to whom the Software
//! is furnished to do so, subject to the following
//! conditions:
//!
//! The above copyright notice and this permission notice
//! shall be included in all copies or substantial portions
//! of the Software.
//!
//! THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
//! ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
//! TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
//! PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
//! SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
//! CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
//! OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
//! IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
//! DEALINGS IN THE SOFTWARE.
//!
//!
//! Handle WebSocket connections.
//!
//! # Example
//!
//! ```
//! use axum::{
//! extract::ws::{WebSocketUpgrade, WebSocket},
//! routing::get,
//! response::{IntoResponse, Response},
//! Router,
//! };
//!
//! let app = Router::new().route("/ws", get(handler));
//!
//! async fn handler(ws: WebSocketUpgrade) -> Response {
//! ws.on_upgrade(handle_socket)
//! }
//!
//! async fn handle_socket(mut socket: WebSocket) {
//! while let Some(msg) = socket.recv().await {
//! let msg = if let Ok(msg) = msg {
//! msg
//! } else {
//! // client disconnected
//! return;
//! };
//!
//! if socket.send(msg).await.is_err() {
//! // client disconnected
//! return;
//! }
//! }
//! }
//! # let _: Router = app;
//! ```
//!
//! # Passing data and/or state to an `on_upgrade` callback
//!
//! ```
//! use axum::{
//! extract::{ws::{WebSocketUpgrade, WebSocket}, State},
//! response::Response,
//! routing::get,
//! Router,
//! };
//!
//! #[derive(Clone)]
//! struct AppState {
//! // ...
//! }
//!
//! async fn handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
//! ws.on_upgrade(|socket| handle_socket(socket, state))
//! }
//!
//! async fn handle_socket(socket: WebSocket, state: AppState) {
//! // ...
//! }
//!
//! let app = Router::new()
//! .route("/ws", get(handler))
//! .with_state(AppState { /* ... */ });
//! # let _: Router = app;
//! ```
//!
//! # Read and write concurrently
//!
//! If you need to read and write concurrently from a [`WebSocket`] you can use
//! [`StreamExt::split`]:
//!
//! ```rust,no_run
//! use axum::{Error, extract::ws::{WebSocket, Message}};
//! use futures_util::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}};
//!
//! async fn handle_socket(mut socket: WebSocket) {
//! let (mut sender, mut receiver) = socket.split();
//!
//! tokio::spawn(write(sender));
//! tokio::spawn(read(receiver));
//! }
//!
//! async fn read(receiver: SplitStream<WebSocket>) {
//! // ...
//! }
//!
//! async fn write(sender: SplitSink<WebSocket, Message>) {
//! // ...
//! }
//! ```
//!
//! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
#![allow(unused)]
use self::rejection::*;
use async_trait::async_trait;
use async_tungstenite::{
tokio::TokioAdapter,
tungstenite::{
self as ts,
protocol::{self, WebSocketConfig},
},
WebSocketStream,
};
use axum::{body::Bytes, extract::FromRequestParts, response::Response, Error};
use axum_core::body::Body;
use futures_util::{
sink::{Sink, SinkExt},
stream::{Stream, StreamExt},
};
use http::{
header::{self, HeaderMap, HeaderName, HeaderValue},
request::Parts,
Method, StatusCode,
};
use hyper_util::rt::TokioIo;
use sha1::{Digest, Sha1};
use std::{
borrow::Cow,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tracing::error;
/// Extractor for establishing WebSocket connections.
///
/// Note: This extractor requires the request method to be `GET` so it should
/// always be used with [`get`](crate::routing::get). Requests with other methods will be
/// rejected.
///
/// See the [module docs](self) for an example.
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
config: WebSocketConfig,
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
protocol: Option<HeaderValue>,
sec_websocket_key: HeaderValue,
on_upgrade: hyper::upgrade::OnUpgrade,
on_failed_upgrade: F,
sec_websocket_protocol: Option<HeaderValue>,
}
impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("WebSocketUpgrade")
.field("config", &self.config)
.field("protocol", &self.protocol)
.field("sec_websocket_key", &self.sec_websocket_key)
.field("sec_websocket_protocol", &self.sec_websocket_protocol)
.finish_non_exhaustive()
}
}
impl<F> WebSocketUpgrade<F> {
/// The target minimum size of the write buffer to reach before writing the data
/// to the underlying stream.
///
/// The default value is 128 KiB.
///
/// If set to `0` each message will be eagerly written to the underlying stream.
/// It is often more optimal to allow them to buffer a little, hence the default value.
///
/// Note: [`flush`](SinkExt::flush) will always fully write the buffer regardless.
pub fn write_buffer_size(mut self, size: usize) -> Self {
self.config.write_buffer_size = size;
self
}
/// The max size of the write buffer in bytes. Setting this can provide backpressure
/// in the case the write buffer is filling up due to write errors.
///
/// The default value is unlimited.
///
/// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size)
/// when writes to the underlying stream are failing. So the **write buffer can not
/// fill up if you are not observing write errors even if not flushing**.
///
/// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size)
/// and probably a little more depending on error handling strategy.
pub fn max_write_buffer_size(mut self, max: usize) -> Self {
self.config.max_write_buffer_size = max;
self
}
/// Set the maximum message size (defaults to 64 megabytes)
pub fn max_message_size(mut self, max: usize) -> Self {
self.config.max_message_size = Some(max);
self
}
/// Set the maximum frame size (defaults to 16 megabytes)
pub fn max_frame_size(mut self, max: usize) -> Self {
self.config.max_frame_size = Some(max);
self
}
/// Allow server to accept unmasked frames (defaults to false)
pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
self.config.accept_unmasked_frames = accept;
self
}
/// Set the known protocols.
///
/// If the protocol name specified by `Sec-WebSocket-Protocol` header
/// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
/// return the protocol name.
///
/// The protocols should be listed in decreasing order of preference: if the client offers
/// multiple protocols that the server could support, the server will pick the first one in
/// this list.
///
/// # Examples
///
/// ```
/// use axum::{
/// extract::ws::{WebSocketUpgrade, WebSocket},
/// routing::get,
/// response::{IntoResponse, Response},
/// Router,
/// };
///
/// let app = Router::new().route("/ws", get(handler));
///
/// async fn handler(ws: WebSocketUpgrade) -> Response {
/// ws.protocols(["graphql-ws", "graphql-transport-ws"])
/// .on_upgrade(|socket| async {
/// // ...
/// })
/// }
/// # let _: Router = app;
/// ```
pub fn protocols<I>(mut self, protocols: I) -> Self
where
I: IntoIterator,
I::Item: Into<Cow<'static, str>>,
{
if let Some(req_protocols) = self
.sec_websocket_protocol
.as_ref()
.and_then(|p| p.to_str().ok())
{
self.protocol = protocols
.into_iter()
// FIXME: This will often allocate a new `String` and so is less efficient than it
// could be. But that can't be fixed without breaking changes to the public API.
.map(Into::into)
.find(|protocol| {
req_protocols
.split(',')
.any(|req_protocol| req_protocol.trim() == protocol)
})
.map(|protocol| match protocol {
Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
Cow::Borrowed(s) => HeaderValue::from_static(s),
});
}
self
}
/// Provide a callback to call if upgrading the connection fails.
///
/// The connection upgrade is performed in a background task. If that fails this callback
/// will be called.
///
/// By default any errors will be silently ignored.
///
/// # Example
///
/// ```
/// use axum::{
/// extract::{WebSocketUpgrade},
/// response::Response,
/// };
///
/// async fn handler(ws: WebSocketUpgrade) -> Response {
/// ws.on_failed_upgrade(|error| {
/// report_error(error);
/// })
/// .on_upgrade(|socket| async { /* ... */ })
/// }
/// #
/// # fn report_error(_: axum::Error) {}
/// ```
pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
where
C: OnFailedUpgrade,
{
WebSocketUpgrade {
config: self.config,
protocol: self.protocol,
sec_websocket_key: self.sec_websocket_key,
on_upgrade: self.on_upgrade,
on_failed_upgrade: callback,
sec_websocket_protocol: self.sec_websocket_protocol,
}
}
/// Finalize upgrading the connection and call the provided callback with
/// the stream.
#[must_use = "to set up the WebSocket connection, this response must be returned"]
pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
where
C: FnOnce(WebSocket) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
F: OnFailedUpgrade,
{
let on_upgrade = self.on_upgrade;
let config = self.config;
let on_failed_upgrade = self.on_failed_upgrade;
let protocol = self.protocol.clone();
tokio::spawn(async move {
let upgraded = match on_upgrade.await {
Ok(upgraded) => upgraded,
Err(err) => {
error!("Something wrong with on_upgrade: {:?}", err);
on_failed_upgrade.call(Error::new(err));
return;
}
};
let upgraded = TokioIo::new(upgraded);
let socket = WebSocketStream::from_raw_socket(
// NOTARY_MODIFICATION: Need to use TokioAdapter to wrap Upgraded which doesn't implement futures crate's AsyncRead and AsyncWrite
TokioAdapter::new(upgraded),
protocol::Role::Server,
Some(config),
)
.await;
let socket = WebSocket {
inner: socket,
protocol,
};
callback(socket).await;
});
#[allow(clippy::declare_interior_mutable_const)]
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
#[allow(clippy::declare_interior_mutable_const)]
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, UPGRADE)
.header(header::UPGRADE, WEBSOCKET)
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(self.sec_websocket_key.as_bytes()),
);
if let Some(protocol) = self.protocol {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
builder.body(Body::empty()).unwrap()
}
}
/// What to do when a connection upgrade fails.
///
/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
pub trait OnFailedUpgrade: Send + 'static {
/// Call the callback.
fn call(self, error: Error);
}
impl<F> OnFailedUpgrade for F
where
F: FnOnce(Error) + Send + 'static,
{
fn call(self, error: Error) {
self(error)
}
}
/// The default `OnFailedUpgrade` used by `WebSocketUpgrade`.
///
/// It simply ignores the error.
#[non_exhaustive]
#[derive(Debug)]
pub struct DefaultOnFailedUpgrade;
impl OnFailedUpgrade for DefaultOnFailedUpgrade {
#[inline]
fn call(self, _error: Error) {}
}
#[async_trait]
impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpgrade>
where
S: Send + Sync,
{
type Rejection = WebSocketUpgradeRejection;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if parts.method != Method::GET {
return Err(MethodNotGet.into());
}
if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
return Err(InvalidConnectionHeader.into());
}
if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
return Err(InvalidUpgradeHeader.into());
}
if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
return Err(InvalidWebSocketVersionHeader.into());
}
let sec_websocket_key = parts
.headers
.get(header::SEC_WEBSOCKET_KEY)
.ok_or(WebSocketKeyHeaderMissing)?
.clone();
let on_upgrade = parts
.extensions
.remove::<hyper::upgrade::OnUpgrade>()
.ok_or(ConnectionNotUpgradable)?;
let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
Ok(Self {
config: Default::default(),
protocol: None,
sec_websocket_key,
on_upgrade,
sec_websocket_protocol,
on_failed_upgrade: DefaultOnFailedUpgrade,
})
}
}
/// NOTARY_MODIFICATION: Made this function public to be used in service.rs
pub fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
if let Some(header) = headers.get(&key) {
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
} else {
false
}
}
fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
let header = if let Some(header) = headers.get(&key) {
header
} else {
return false;
};
if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
header.to_ascii_lowercase().contains(value)
} else {
false
}
}
/// A stream of WebSocket messages.
///
/// See [the module level documentation](self) for more details.
#[derive(Debug)]
pub struct WebSocket {
inner: WebSocketStream<TokioAdapter<TokioIo<hyper::upgrade::Upgraded>>>,
protocol: Option<HeaderValue>,
}
impl WebSocket {
/// NOTARY_MODIFICATION: Consume `self` and get the inner [`async_tungstenite::WebSocketStream`].
pub fn into_inner(self) -> WebSocketStream<TokioAdapter<TokioIo<hyper::upgrade::Upgraded>>> {
self.inner
}
/// Receive another message.
///
/// Returns `None` if the stream has closed.
pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
self.next().await
}
/// Send a message.
pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
self.inner
.send(msg.into_tungstenite())
.await
.map_err(Error::new)
}
/// Gracefully close this WebSocket.
pub async fn close(mut self) -> Result<(), Error> {
self.inner.close(None).await.map_err(Error::new)
}
/// Return the selected WebSocket subprotocol, if one has been chosen.
pub fn protocol(&self) -> Option<&HeaderValue> {
self.protocol.as_ref()
}
}
impl Stream for WebSocket {
type Item = Result<Message, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
Some(Ok(msg)) => {
if let Some(msg) = Message::from_tungstenite(msg) {
return Poll::Ready(Some(Ok(msg)));
}
}
Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))),
None => return Poll::Ready(None),
}
}
}
}
impl Sink<Message> for WebSocket {
type Error = Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new)
}
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
Pin::new(&mut self.inner)
.start_send(item.into_tungstenite())
.map_err(Error::new)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new)
}
}
/// Status code used to indicate why an endpoint is closing the WebSocket connection.
pub type CloseCode = u16;
/// A struct representing the close command.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame<'t> {
/// The reason as a code.
pub code: CloseCode,
/// The reason as text string.
pub reason: Cow<'t, str>,
}
/// A WebSocket message.
//
// This code comes from https://github.com/snapview/tungstenite-rs/blob/master/src/protocol/message.rs and is under following license:
// Copyright (c) 2017 Alexey Galakhov
// Copyright (c) 2016 Jason Housley
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message {
/// A text WebSocket message
Text(String),
/// A binary WebSocket message
Binary(Vec<u8>),
/// A ping message with the specified payload
///
/// The payload here must have a length less than 125 bytes.
///
/// Ping messages will be automatically responded to by the server, so you do not have to worry
/// about dealing with them yourself.
Ping(Vec<u8>),
/// A pong message with the specified payload
///
/// The payload here must have a length less than 125 bytes.
///
/// Pong messages will be automatically sent to the client if a ping message is received, so
/// you do not have to worry about constructing them yourself unless you want to implement a
/// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
Pong(Vec<u8>),
/// A close message with the optional close frame.
Close(Option<CloseFrame<'static>>),
}
impl Message {
fn into_tungstenite(self) -> ts::Message {
match self {
Self::Text(text) => ts::Message::Text(text),
Self::Binary(binary) => ts::Message::Binary(binary),
Self::Ping(ping) => ts::Message::Ping(ping),
Self::Pong(pong) => ts::Message::Pong(pong),
Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
code: ts::protocol::frame::coding::CloseCode::from(close.code),
reason: close.reason,
})),
Self::Close(None) => ts::Message::Close(None),
}
}
fn from_tungstenite(message: ts::Message) -> Option<Self> {
match message {
ts::Message::Text(text) => Some(Self::Text(text)),
ts::Message::Binary(binary) => Some(Self::Binary(binary)),
ts::Message::Ping(ping) => Some(Self::Ping(ping)),
ts::Message::Pong(pong) => Some(Self::Pong(pong)),
ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
code: close.code.into(),
reason: close.reason,
}))),
ts::Message::Close(None) => Some(Self::Close(None)),
// we can ignore `Frame` frames as recommended by the tungstenite maintainers
// https://github.com/snapview/tungstenite-rs/issues/268
ts::Message::Frame(_) => None,
}
}
/// Consume the WebSocket and return it as binary data.
pub fn into_data(self) -> Vec<u8> {
match self {
Self::Text(string) => string.into_bytes(),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
Self::Close(None) => Vec::new(),
Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
}
}
/// Attempt to consume the WebSocket message and convert it to a String.
pub fn into_text(self) -> Result<String, Error> {
match self {
Self::Text(string) => Ok(string),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data)
.map_err(|err| err.utf8_error())
.map_err(Error::new)?),
Self::Close(None) => Ok(String::new()),
Self::Close(Some(frame)) => Ok(frame.reason.into_owned()),
}
}
/// Attempt to get a &str from the WebSocket message,
/// this will try to convert binary data to utf8.
pub fn to_text(&self) -> Result<&str, Error> {
match *self {
Self::Text(ref string) => Ok(string),
Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
Ok(std::str::from_utf8(data).map_err(Error::new)?)
}
Self::Close(None) => Ok(""),
Self::Close(Some(ref frame)) => Ok(&frame.reason),
}
}
}
impl From<String> for Message {
fn from(string: String) -> Self {
Message::Text(string)
}
}
impl<'s> From<&'s str> for Message {
fn from(string: &'s str) -> Self {
Message::Text(string.into())
}
}
impl<'b> From<&'b [u8]> for Message {
fn from(data: &'b [u8]) -> Self {
Message::Binary(data.into())
}
}
impl From<Vec<u8>> for Message {
fn from(data: Vec<u8>) -> Self {
Message::Binary(data)
}
}
impl From<Message> for Vec<u8> {
fn from(msg: Message) -> Self {
msg.into_data()
}
}
fn sign(key: &[u8]) -> HeaderValue {
use base64::engine::Engine as _;
let mut sha1 = Sha1::default();
sha1.update(key);
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
}
pub mod rejection {
//! WebSocket specific rejections.
use axum_core::{
__composite_rejection as composite_rejection, __define_rejection as define_rejection,
};
define_rejection! {
#[status = METHOD_NOT_ALLOWED]
#[body = "Request method must be `GET`"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct MethodNotGet;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Connection header did not include 'upgrade'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidConnectionHeader;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Upgrade` header did not include 'websocket'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidUpgradeHeader;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Sec-WebSocket-Version` header did not include '13'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidWebSocketVersionHeader;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Sec-WebSocket-Key` header missing"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct WebSocketKeyHeaderMissing;
}
define_rejection! {
#[status = UPGRADE_REQUIRED]
#[body = "WebSocket request couldn't be upgraded since no upgrade state was present"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
///
/// This rejection is returned if the connection cannot be upgraded for example if the
/// request is HTTP/1.0.
///
/// See [MDN] for more details about connection upgrades.
///
/// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Upgrade
pub struct ConnectionNotUpgradable;
}
composite_rejection! {
/// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
///
/// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
/// extractor can fail.
pub enum WebSocketUpgradeRejection {
MethodNotGet,
InvalidConnectionHeader,
InvalidUpgradeHeader,
InvalidWebSocketVersionHeader,
WebSocketKeyHeaderMissing,
ConnectionNotUpgradable,
}
}
}
pub mod close_code {
//! Constants for [`CloseCode`]s.
//!
//! [`CloseCode`]: super::CloseCode
/// Indicates a normal closure, meaning that the purpose for which the connection was
/// established has been fulfilled.
pub const NORMAL: u16 = 1000;
/// Indicates that an endpoint is "going away", such as a server going down or a browser having
/// navigated away from a page.
pub const AWAY: u16 = 1001;
/// Indicates that an endpoint is terminating the connection due to a protocol error.
pub const PROTOCOL: u16 = 1002;
/// Indicates that an endpoint is terminating the connection because it has received a type of
/// data it cannot accept (e.g., an endpoint that understands only text data MAY send this if
/// it receives a binary message).
pub const UNSUPPORTED: u16 = 1003;
/// Indicates that no status code was included in a closing frame.
pub const STATUS: u16 = 1005;
/// Indicates an abnormal closure.
pub const ABNORMAL: u16 = 1006;
/// Indicates that an endpoint is terminating the connection because it has received data
/// within a message that was not consistent with the type of the message (e.g., non-UTF-8
/// RFC3629 data within a text message).
pub const INVALID: u16 = 1007;
/// Indicates that an endpoint is terminating the connection because it has received a message
/// that violates its policy. This is a generic status code that can be returned when there is
/// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a need to
/// hide specific details about the policy.
pub const POLICY: u16 = 1008;
/// Indicates that an endpoint is terminating the connection because it has received a message
/// that is too big for it to process.
pub const SIZE: u16 = 1009;
/// Indicates that an endpoint (client) is terminating the connection because it has expected
/// the server to negotiate one or more extension, but the server didn't return them in the
/// response message of the WebSocket handshake. The list of extensions that are needed should
/// be given as the reason for closing. Note that this status code is not used by the server,
/// because it can fail the WebSocket handshake instead.
pub const EXTENSION: u16 = 1010;
/// Indicates that a server is terminating the connection because it encountered an unexpected
/// condition that prevented it from fulfilling the request.
pub const ERROR: u16 = 1011;
/// Indicates that the server is restarting.
pub const RESTART: u16 = 1012;
/// Indicates that the server is overloaded and the client should either connect to a different
/// IP (when multiple targets exist), or reconnect to the same IP when a user has performed an
/// action.
pub const AGAIN: u16 = 1013;
}
#[cfg(test)]
mod tests {
use super::*;
use axum::{body::Body, routing::get, Router};
use http::{Request, Version};
// NOTARY_MODIFICATION: use tower_util instead of tower to make clippy happy
use tower_util::ServiceExt;
#[tokio::test]
async fn rejects_http_1_0_requests() {
let svc = get(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
let rejection = ws.unwrap_err();
assert!(matches!(
rejection,
WebSocketUpgradeRejection::ConnectionNotUpgradable(_)
));
std::future::ready(())
});
let req = Request::builder()
.version(Version::HTTP_10)
.method(Method::GET)
.header("upgrade", "websocket")
.header("connection", "Upgrade")
.header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==")
.header("sec-websocket-version", "13")
.body(Body::empty())
.unwrap();
let res = svc.oneshot(req).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
#[allow(dead_code)]
fn default_on_failed_upgrade() {
async fn handler(ws: WebSocketUpgrade) -> Response {
ws.on_upgrade(|_| async {})
}
let _: Router = Router::new().route("/", get(handler));
}
#[allow(dead_code)]
fn on_failed_upgrade() {
async fn handler(ws: WebSocketUpgrade) -> Response {
ws.on_failed_upgrade(|_error: Error| println!("oops!"))
.on_upgrade(|_| async {})
}
let _: Router = Router::new().route("/", get(handler));
}
}

View File

@@ -0,0 +1,176 @@
use axum::{
extract::{Request, State},
response::IntoResponse,
routing::get,
Router,
};
use axum_websocket::{WebSocket, WebSocketUpgrade};
use eyre::eyre;
use hyper::{body::Incoming, server::conn::http1};
use hyper_util::rt::TokioIo;
use std::{
net::{IpAddr, SocketAddr},
sync::Arc,
};
use tlsn_common::config::ProtocolConfigValidator;
use tlsn_verifier::{SessionInfo, Verifier, VerifierConfig};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpListener,
};
use tokio_util::compat::TokioAsyncReadCompatExt;
use tower_service::Service;
use tracing::{debug, error, info};
use ws_stream_tungstenite::WsStream;
mod axum_websocket;
// Maximum number of bytes that can be sent from prover to server
const MAX_SENT_DATA: usize = 1 << 12;
// Maximum number of bytes that can be received by prover from server
const MAX_RECV_DATA: usize = 1 << 14;
/// Global data that needs to be shared with the axum handlers
#[derive(Clone, Debug)]
struct VerifierGlobals {
pub server_domain: String,
}
pub async fn run_server(
verifier_host: &str,
verifier_port: u16,
server_domain: &str,
) -> Result<(), eyre::ErrReport> {
let verifier_address = SocketAddr::new(
IpAddr::V4(verifier_host.parse().map_err(|err| {
eyre!("Failed to parse verifer host address from server config: {err}")
})?),
verifier_port,
);
let listener = TcpListener::bind(verifier_address)
.await
.map_err(|err| eyre!("Failed to bind server address to tcp listener: {err}"))?;
info!("Listening for TCP traffic at {}", verifier_address);
let protocol = Arc::new(http1::Builder::new());
let router = Router::new()
.route("/verify", get(ws_handler))
.with_state(VerifierGlobals {
server_domain: server_domain.to_string(),
});
loop {
let stream = match listener.accept().await {
Ok((stream, _)) => stream,
Err(err) => {
error!("Failed to connect to prover: {err}");
continue;
}
};
debug!("Received a prover's TCP connection");
let tower_service = router.clone();
let protocol = protocol.clone();
tokio::spawn(async move {
info!("Accepted prover's TCP connection",);
// Reference: https://github.com/tokio-rs/axum/blob/5201798d4e4d4759c208ef83e30ce85820c07baa/examples/low-level-rustls/src/main.rs#L67-L80
let io = TokioIo::new(stream);
let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
tower_service.clone().call(request)
});
// Serve different requests using the same hyper protocol and axum router
let _ = protocol
.serve_connection(io, hyper_service)
// use with_upgrades to upgrade connection to websocket for websocket clients
// and to extract tcp connection for tcp clients
.with_upgrades()
.await;
});
}
}
async fn ws_handler(
ws: WebSocketUpgrade,
State(verifier_globals): State<VerifierGlobals>,
) -> impl IntoResponse {
info!("Received websocket request");
ws.on_upgrade(|socket| handle_socket(socket, verifier_globals))
}
async fn handle_socket(socket: WebSocket, verifier_globals: VerifierGlobals) {
debug!("Upgraded to websocket connection");
let stream = WsStream::new(socket.into_inner());
match verifier(stream, &verifier_globals.server_domain).await {
Ok((sent, received, _session_info)) => {
info!("Successfully verified {}", &verifier_globals.server_domain);
info!("Verified sent data:\n{}", sent,);
println!("Verified received data:\n{}", received,);
}
Err(err) => {
error!("Failed verification using websocket: {err}");
}
}
}
async fn verifier<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
socket: T,
server_domain: &str,
) -> Result<(String, String, SessionInfo), eyre::ErrReport> {
debug!("Starting verification...");
// Setup Verifier.
let config_validator = ProtocolConfigValidator::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()
.unwrap();
let verifier_config = VerifierConfig::builder()
.protocol_config_validator(config_validator)
.build()
.unwrap();
let verifier = Verifier::new(verifier_config);
// Verify MPC-TLS and wait for (redacted) data.
debug!("Starting MPC-TLS verification...");
// Verify MPC-TLS and wait for (redacted) data.
let (mut partial_transcript, session_info) = verifier.verify(socket.compat()).await.unwrap();
partial_transcript.set_unauthed(0);
// Check sent data: check host.
debug!("Starting sent data verification...");
let sent = partial_transcript.sent_unsafe().to_vec();
let sent_data = String::from_utf8(sent.clone()).expect("Verifier expected sent data");
sent_data
.find(server_domain)
.ok_or_else(|| eyre!("Verification failed: Expected host {}", server_domain))?;
// Check received data: check json and version number.
debug!("Starting received data verification...");
let received = partial_transcript.received_unsafe().to_vec();
let response = String::from_utf8(received.clone()).expect("Verifier expected received data");
debug!("Received data: {:?}", response);
response
.find("eye_color")
.ok_or_else(|| eyre!("Verification failed: missing eye_color in received data"))?;
// Check Session info: server name.
if session_info.server_name.as_str() != server_domain {
return Err(eyre!("Verification failed: server name mismatches"));
}
let sent_string = bytes_to_redacted_string(&sent)?;
let received_string = bytes_to_redacted_string(&received)?;
Ok((sent_string, received_string, session_info))
}
/// Render redacted bytes as `🙈`.
fn bytes_to_redacted_string(bytes: &[u8]) -> Result<String, eyre::ErrReport> {
Ok(String::from_utf8(bytes.to_vec())
.map_err(|err| eyre!("Failed to parse bytes to redacted string: {err}"))?
.replace('\0', "🙈"))
}

View File

@@ -0,0 +1,22 @@
use interactive_networked_verifier::run_server;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
const TRACING_FILTER: &str = "INFO";
const VERIFIER_HOST: &str = "0.0.0.0";
const VERIFIER_PORT: u16 = 9816;
/// Make sure the following domain is the same in SERVER_URL on the prover side
const SERVER_DOMAIN: &str = "swapi.dev";
#[tokio::main]
async fn main() -> Result<(), eyre::ErrReport> {
tracing_subscriber::registry()
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| TRACING_FILTER.into()))
.with(tracing_subscriber::fmt::layer())
.init();
run_server(VERIFIER_HOST, VERIFIER_PORT, SERVER_DOMAIN).await?;
Ok(())
}

View File

@@ -8,7 +8,6 @@
"files": [
"build/",
"src/",
"wasm/pkg/*",
"readme.md"
],
"scripts": {

View File

@@ -56,9 +56,9 @@ export default async function init(config?: {
});
// 6422528 ~= 6.12 mb
debug('res.memory=', res.memory);
debug('res.memory.buffer.length=', res.memory.buffer.byteLength);
debug('DEBUG', 'initialize thread pool');
debug('res.memory', res.memory);
debug('res.memory.buffer.length', res.memory.buffer.byteLength);
debug('initialize thread pool');
await initThreadPool(hardwareConcurrency);
debug('initialized thread pool');
@@ -273,6 +273,7 @@ export class Prover {
async reveal(reveal: Reveal) {
return this.#prover.reveal(reveal);
}
}
export class Verifier {