Ns 2 http websocket (#6)

* Setup repo.

* Correct typo.

* Add tcp listener with tls.

* Hook up notary service.

* Fix notary signing key loading.

* Add otel tracing.

* Add test to test prover and notary integration.

* Fix notarization test, add comments.

* Fix github action.

* Fix span logging, github actions.

* Add logic to promote to http and then downgrade to tcp for notarization.

* Fix client hang issue

* Change channel message type.

* Fix response parsing from notary.

* Fix websocket implementation and use upgrade protocol for raw tcp.

* Modify test to mimick browser extension for websocket test.

* Refactor tcp client handling.

* Add global store for persistent data.

* Finish websocket handler and test.

* Add comments.

* Add more comments and documentation.

* Add openapi.yaml.

* Modify README.

* Add architecture explanation.

* Modify README.

* Fix PR based on comments.

* Combine tcp and websocket extractors.

* Refactor and fix documentations.
This commit is contained in:
Christopher Chong
2023-08-14 12:38:20 +08:00
committed by GitHub
parent 0e9fadce01
commit 3fcd517c7f
18 changed files with 1859 additions and 84 deletions

View File

@@ -6,10 +6,17 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
async-trait = "0.1.67"
async-tungstenite = { version = "0.22.2", features = ["tokio-runtime", "tokio-native-tls"] }
axum = { version = "0.6.18", features = ["ws"]}
axum-core = "0.3.4"
axum-macros = "0.3.8"
base64 = "0.21.0"
eyre = "0.6.8"
futures = "0.3"
futures-util = "0.3.28"
hyper = { version = "0.14", features = ["client", "http1"] }
http = "0.2.9"
hyper = { version = "0.14", features = ["client", "http1", "server", "tcp"] }
opentelemetry = { version = "0.19" }
p256 = "0.13"
rustls = { version = "0.21" }
@@ -17,15 +24,23 @@ rustls-pemfile = { version = "1.0.2" }
serde = { version = "1.0.147", features = ["derive"] }
serde_json = "1.0"
serde_yaml = "0.9.21"
sha1 = "0.10"
structopt = "0.3.26"
thiserror = "1"
tls-server-fixture = { git = "https://github.com/tlsnotary/tlsn" }
tlsn-notary = { git = "https://github.com/tlsnotary/tlsn" }
tlsn-prover = { git = "https://github.com/tlsnotary/tlsn" }
tlsn-tls-core = { git = "https://github.com/tlsnotary/tlsn" }
tokio = { version = "1", features = ["full"] }
tokio-rustls = { version = "0.24.1" }
tokio-util = { version = "0.7", features = ["compat"] }
tower = { version = "0.4.12", features = ["make"] }
tracing = "0.1"
tracing-opentelemetry = "0.19"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.4.1", features = ["v4", "fast-rng"] }
ws_stream_tungstenite = { version = "0.10.0", features = ["tokio_io"] }
[dev-dependencies]
hyper-tls = "0.5.0"
tls-server-fixture = { git = "https://github.com/tlsnotary/tlsn" }
tlsn-prover = { git = "https://github.com/tlsnotary/tlsn" }
tokio-native-tls = "0.3.1"

View File

@@ -8,3 +8,53 @@ An implementation of the notary server in Rust.
## ⚠️ Notice
This project is currently under active development and should not be used in production. Expect bugs and regular major breaking changes.
---
## Running the server
1. Configure the server setting in this [file](./config.yaml) — refer [here](./src/config.rs) for more information on the definition of the setting parameters.
2. Start the server by running following in a terminal at the top level of this project.
```bash
cargo run
```
3. To use a config file from a different location, run the following command to override the default config file location.
```bash
cargo run -- --config-file <path-to-new-config-file>
```
---
## API
All APIs are TLS-protected, hence please use `https://` or `wss://`.
### HTTP APIs
Defined in the [OpenAPI specification](./openapi.yaml).
### WebSocket APIs
#### /notarize
##### Description
To perform notarization using the session id (unique id returned upon calling the `/session` endpoint successfully) submitted as a custom header.
##### Custom Header
`X-Session-Id`
##### Custom Header Type
String
---
## Architecture
### Objective
The main objective of a notary server is to perform notarization together with a prover. In this case, the prover can either be
1. TCP client — which has access and control over the transport layer, i.e. TCP
2. WebSocket client — which has no access over TCP and instead uses WebSocket for notarization
### Design Choices
#### Web Framework
Axum is chosen as the framework to serve HTTP and WebSocket requests from the prover clients due to its rich and well supported features, e.g. native integration with Tokio/Hyper/Tower, customizable middleware, ability to support lower level integration of TLS ([example](https://github.com/tokio-rs/axum/blob/main/examples/low-level-rustls/src/main.rs)). To simplify the notary server setup, a single Axum router is used to support both HTTP and WebSocket connections, i.e. all requests can be made to the same port of the notary server.
#### Notarization Configuration
To perform notarization, some parameters need to be configured by the prover and notary server (more details in the [OpenAPI specification](./openapi.yaml)), i.e.
- maximum transcript size
- unique session id
To streamline this process, a single HTTP endpoint (`/session`) is used by both TCP and WebSocket clients.
#### WebSocket
Axum's internal implementation of WebSocket uses [tokio_tungstenite](https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/), which provides a WebSocket struct that doesn't implement [AsyncRead](https://docs.rs/futures/latest/futures/io/trait.AsyncRead.html) and [AsyncWrite](https://docs.rs/futures/latest/futures/io/trait.AsyncWrite.html). Both these traits are required by TLSN core libraries for prover and notary. To overcome this, a [slight modification](./src/service/axum_websocket.rs) of Axum's implementation of WebSocket is used, where [async_tungstenite](https://docs.rs/async-tungstenite/latest/async_tungstenite/) is used instead so that [ws_stream_tungstenite](https://docs.rs/ws_stream_tungstenite/latest/ws_stream_tungstenite/index.html) can be used to wrap on top of the WebSocket struct to get AsyncRead and AsyncWrite implemented.

View File

@@ -3,6 +3,9 @@ server:
domain: "127.0.0.1"
port: 7047
notarization:
max-transcript-size: 16384
tls-signature:
private-key-pem-path: "./src/fixture/tls/notary.key"
certificate-pem-path: "./src/fixture/tls/notary.crt"

123
openapi.yaml Normal file
View File

@@ -0,0 +1,123 @@
openapi: 3.0.0
info:
title: Notary Server
description: Notary server written in Rust to provide notarization service.
version: 0.1.0
tags:
- name: Notarization
paths:
/session:
post:
tags:
- Notarization
description: Initialize and configure notarization for both TCP and WebSocket clients
parameters:
- in: header
name: Content-Type
description: The value must be application/json
schema:
type: string
enum:
- "application/json"
required: true
requestBody:
description: Notarization session request to server
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/NotarizationSessionRequest"
responses:
"200":
description: Notarization session response from server
content:
application/json:
schema:
$ref: "#/components/schemas/NotarizationSessionResponse"
"400":
description: Configuration parameters or headers provided by prover are invalid
content:
text/plain:
schema:
type: string
example: "Invalid request from prover: Failed to deserialize the JSON body into the target type"
"500":
description: There was some internal error when processing
content:
text/plain:
schema:
type: string
example: "Something is wrong"
/notarize:
get:
tags:
- Notarization
description: Start notarization for TCP client
parameters:
- in: header
name: Connection
description: The value should be 'Upgrade'
schema:
type: string
enum:
- "Upgrade"
required: true
- in: header
name: Upgrade
description: The value should be 'TCP'
schema:
type: string
enum:
- "TCP"
required: true
- in: header
name: X-Session-Id
description: Unique ID returned from server upon calling POST /session
schema:
type: string
required: true
responses:
"101":
description: Switching protocol response
"400":
description: Headers provided by prover are invalid
content:
text/plain:
schema:
type: string
example: "Invalid request from prover: Upgrade header is not set for client"
"500":
description: There was some internal error when processing
content:
text/plain:
schema:
type: string
example: "Something is wrong"
components:
schemas:
NotarizationSessionRequest:
type: object
properties:
clientType:
description: Types of client that the prover is using
type: string
enum:
- "Tcp"
- "Websocket"
maxTranscriptSize:
description: Maximum transcript size in bytes
type: integer
required:
- "clientType"
- "maxTranscriptSize"
NotarizationSessionResponse:
type: object
properties:
sessionId:
type: string
required:
- "sessionId"

View File

@@ -5,6 +5,8 @@ use serde::Deserialize;
pub struct NotaryServerProperties {
/// Name and address of the notary server
pub server: ServerProperties,
/// Setting for notarization
pub notarization: NotarizationProperties,
/// File path of private key and certificate (in PEM format) used for establishing TLS with prover
pub tls_signature: TLSSignatureProperties,
/// File path of private key (in PEM format) used to sign the notarisation
@@ -13,6 +15,13 @@ pub struct NotaryServerProperties {
pub tracing: TracingProperties,
}
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct NotarizationProperties {
/// Global limit for maximum transcript size in bytes
pub max_transcript_size: usize,
}
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct ServerProperties {

View File

@@ -1 +1,2 @@
pub mod cli;
pub mod notary;

View File

@@ -5,6 +5,6 @@ use structopt::StructOpt;
#[structopt(name = "Notary Server")]
pub struct CliFields {
/// Configuration file location
#[structopt(long, default_value = "./src/config/config.yaml")]
#[structopt(long, default_value = "./config.yaml")]
pub config_file: String,
}

55
src/domain/notary.rs Normal file
View File

@@ -0,0 +1,55 @@
use std::{collections::HashMap, sync::Arc};
use p256::ecdsa::SigningKey;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use crate::config::NotarizationProperties;
/// Response object of the /session API
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NotarizationSessionResponse {
/// Unique session id that is generated by notary and shared to prover
pub session_id: String,
}
/// Request object of the /session API
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NotarizationSessionRequest {
pub client_type: ClientType,
/// Maximum transcript size in bytes
pub max_transcript_size: Option<usize>,
}
/// Types of client that the prover is using
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ClientType {
/// Client that has access to the transport layer
Tcp,
/// Client that cannot directly access transport layer, e.g. browser extension
Websocket,
}
/// Global data that needs to be shared with the axum handlers
#[derive(Clone, Debug)]
pub struct NotaryGlobals {
pub notary_signing_key: SigningKey,
pub notarization_config: NotarizationProperties,
/// A temporary storage to store configuration data, mainly used for WebSocket client
pub store: Arc<Mutex<HashMap<String, Option<usize>>>>,
}
impl NotaryGlobals {
pub fn new(
notary_signing_key: SigningKey,
notarization_config: NotarizationProperties,
) -> Self {
Self {
notary_signing_key,
notarization_config,
store: Default::default(),
}
}
}

View File

@@ -1,3 +1,7 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use eyre::Report;
use std::error::Error;
@@ -11,6 +15,8 @@ pub enum NotaryServerError {
Connection(String),
#[error("Error occurred during notarization: {0}")]
Notarization(Box<dyn Error + Send + 'static>),
#[error("Invalid request from prover: {0}")]
BadProverRequest(String),
}
impl From<NotaryError> for NotaryServerError {
@@ -24,3 +30,19 @@ impl From<NotaryConfigBuilderError> for NotaryServerError {
Self::Notarization(Box::new(error))
}
}
/// Trait implementation to convert this error into an axum http response
impl IntoResponse for NotaryServerError {
fn into_response(self) -> Response {
match self {
bad_request_error @ NotaryServerError::BadProverRequest(_) => {
(StatusCode::BAD_REQUEST, bad_request_error.to_string()).into_response()
}
_ => (
StatusCode::INTERNAL_SERVER_ERROR,
"Something wrong happened.",
)
.into_response(),
}
}
}

View File

@@ -3,14 +3,18 @@ mod domain;
mod error;
mod server;
mod server_tracing;
mod service;
mod util;
pub use config::{
NotaryServerProperties, NotarySignatureProperties, ServerProperties, TLSSignatureProperties,
TracingProperties,
NotarizationProperties, NotaryServerProperties, NotarySignatureProperties, ServerProperties,
TLSSignatureProperties, TracingProperties,
};
pub use domain::{
cli::CliFields,
notary::{ClientType, NotarizationSessionRequest, NotarizationSessionResponse},
};
pub use domain::cli::CliFields;
pub use error::NotaryServerError;
pub use server::{read_pem_file, run_tcp_server};
pub use server::{read_pem_file, run_server};
pub use server_tracing::init_tracing;
pub use util::parse_config_file;

View File

@@ -3,7 +3,7 @@ use structopt::StructOpt;
use tracing::debug;
use notary_server::{
init_tracing, parse_config_file, run_tcp_server, CliFields, NotaryServerError,
init_tracing, parse_config_file, run_server, CliFields, NotaryServerError,
NotaryServerProperties,
};
@@ -18,8 +18,8 @@ async fn main() -> Result<(), NotaryServerError> {
debug!(?config, "Server config loaded");
// Run the tcp server
run_tcp_server(&config).await?;
// Run the server
run_server(&config).await?;
Ok(())
}

View File

@@ -1,47 +1,55 @@
use axum::{
http::{Request, StatusCode},
response::IntoResponse,
routing::{get, post},
Router,
};
use eyre::{ensure, eyre, Result};
use futures_util::future::poll_fn;
use p256::{
ecdsa::{Signature, SigningKey},
pkcs8::DecodePrivateKey,
use hyper::server::{
accept::Accept,
conn::{AddrIncoming, Http},
};
use p256::{ecdsa::SigningKey, pkcs8::DecodePrivateKey};
use rustls::{Certificate, PrivateKey, ServerConfig};
use std::{
fs::File as StdFile,
io::BufReader,
net::{IpAddr, SocketAddr},
pin::Pin,
sync::Arc,
};
use tlsn_notary::{bind_notary, NotaryConfig};
use tokio::{
fs::File,
io::{AsyncRead, AsyncWrite},
net::TcpListener,
};
use tokio::{fs::File, net::TcpListener};
use tokio_rustls::TlsAcceptor;
use tokio_util::compat::TokioAsyncReadCompatExt;
use tower::MakeService;
use tracing::{debug, error, info};
use crate::{
config::{NotaryServerProperties, NotarySignatureProperties, TLSSignatureProperties},
domain::notary::NotaryGlobals,
error::NotaryServerError,
service::{initialize, upgrade_protocol},
};
/// Start a TLS-secured TCP server to accept notarization request
/// Start a TLS-secured TCP server to accept notarization request for both TCP and WebSocket clients
#[tracing::instrument(skip(config))]
pub async fn run_tcp_server(config: &NotaryServerProperties) -> Result<(), NotaryServerError> {
pub async fn run_server(config: &NotaryServerProperties) -> Result<(), NotaryServerError> {
// Load the private key and cert needed for TLS connection from fixture folder — can be swapped out when we stop using static self signed cert
let (tls_private_key, tls_certificates) = load_tls_key_and_cert(&config.tls_signature).await?;
// Load the private key for notarized transcript signing from fixture folder — can be swapped out when we use proper ephemeral signing key
let notary_signing_key = load_notary_signing_key(&config.notary_signature).await?;
// Build a TCP listener with TLS enabled
let tls_config = Arc::new(
ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(tls_certificates, tls_private_key)
.map_err(|err| eyre!("Failed to instantiate notary server tls config: {err}"))?,
);
let mut server_config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(tls_certificates, tls_private_key)
.map_err(|err| eyre!("Failed to instantiate notary server tls config: {err}"))?;
// Set the http protocols we support
server_config.alpn_protocols = vec![b"http/1.1".to_vec()];
let tls_config = Arc::new(server_config);
let notary_address = SocketAddr::new(
IpAddr::V4(config.server.domain.parse().map_err(|err| {
@@ -54,25 +62,42 @@ pub async fn run_tcp_server(config: &NotaryServerProperties) -> Result<(), Notar
let listener = TcpListener::bind(notary_address)
.await
.map_err(|err| eyre!("Failed to bind server address to tcp listener: {err}"))?;
let mut listener = AddrIncoming::from_listener(listener)
.map_err(|err| eyre!("Failed to build hyper tcp listener: {err}"))?;
info!(
"Listening for TLS-secured TCP traffic at {}",
notary_address
);
let protocol = Arc::new(Http::new());
let notary_globals = NotaryGlobals::new(notary_signing_key, config.notarization.clone());
let router = Router::new()
.route(
"/healthcheck",
get(|| async move { (StatusCode::OK, "Ok").into_response() }),
)
.route("/session", post(initialize))
.route("/notarize", get(upgrade_protocol))
.with_state(notary_globals);
let mut app = router.into_make_service();
loop {
// Poll for any incoming connection constantly
let (stream, prover_address) = match poll_fn(|cx| listener.poll_accept(cx)).await {
Ok(connection) => connection,
Err(err) => {
error!("{}", NotaryServerError::Connection(err.to_string()));
continue;
}
};
// Poll and await for any incoming connection, ensure that all operations inside are infallible to prevent bringing down the server
let (prover_address, stream) =
match poll_fn(|cx| Pin::new(&mut listener).poll_accept(cx)).await {
Some(Ok(connection)) => (connection.remote_addr(), connection),
Some(Err(err)) => {
error!("{}", NotaryServerError::Connection(err.to_string()));
continue;
}
None => unreachable!("The poll_accept method should never return None"),
};
debug!(?prover_address, "Received a prover's TCP connection");
let acceptor = acceptor.clone();
let notary_signing_key = notary_signing_key.clone();
let protocol = protocol.clone();
let service = MakeService::<_, Request<hyper::Body>>::make_service(&mut app, &stream);
// Spawn a new async task to handle the new connection
tokio::spawn(async move {
@@ -82,16 +107,14 @@ pub async fn run_tcp_server(config: &NotaryServerProperties) -> Result<(), Notar
?prover_address,
"Accepted prover's TLS-secured TCP connection",
);
match notary_service(stream, &prover_address.to_string(), &notary_signing_key)
.await
{
Ok(_) => {
info!(?prover_address, "Successful notarization!");
}
Err(err) => {
error!(?prover_address, "Failed notarization: {err}");
}
}
// Serve different requests using the same hyper protocol and axum router
let _ = protocol
// Can unwrap because it's infallible
.serve_connection(stream, service.await.unwrap())
// use with_upgrades to upgrade connection to websocket for websocket clients
// and to extract tcp connection for tcp clients
.with_upgrades()
.await;
}
Err(err) => {
error!(
@@ -105,22 +128,6 @@ pub async fn run_tcp_server(config: &NotaryServerProperties) -> Result<(), Notar
}
}
/// Run the notarization
async fn notary_service<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: T,
prover_address: &str,
signing_key: &SigningKey,
) -> Result<(), NotaryServerError> {
debug!(?prover_address, "Starting notarization...");
// Temporarily use the prover address as the notarization session id as it is unique for each prover
let config = NotaryConfig::builder().id(prover_address).build()?;
let (notary, notary_fut) = bind_notary(config, socket.compat())?;
// Run the notary and background processes concurrently
tokio::try_join!(notary_fut, notary.notarize::<Signature>(signing_key),).map(|_| Ok(()))?
}
/// Temporary function to load notary signing key from static file
async fn load_notary_signing_key(config: &NotarySignatureProperties) -> Result<SigningKey> {
debug!("Loading notary server's signing key");

185
src/service.rs Normal file
View File

@@ -0,0 +1,185 @@
pub mod axum_websocket;
pub mod tcp;
pub mod websocket;
use async_trait::async_trait;
use axum::{
extract::{rejection::JsonRejection, FromRequestParts, State},
http::{header, request::Parts, HeaderMap, StatusCode},
response::{IntoResponse, Json, Response},
};
use axum_macros::debug_handler;
use p256::ecdsa::{Signature, SigningKey};
use tlsn_notary::{bind_notary, NotaryConfig};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::compat::TokioAsyncReadCompatExt;
use tracing::{debug, error, info, trace};
use uuid::Uuid;
use crate::{
domain::notary::{NotarizationSessionRequest, NotarizationSessionResponse, NotaryGlobals},
error::NotaryServerError,
service::{
axum_websocket::{header_eq, WebSocketUpgrade},
tcp::{tcp_notarize, TcpUpgrade},
websocket::websocket_notarize,
},
};
/// A wrapper enum to facilitate extracting TCP connection for either WebSocket or TCP clients,
/// so that we can use a single endpoint and handler for notarization for both types of clients
pub enum ProtocolUpgrade {
Tcp(TcpUpgrade),
Ws(WebSocketUpgrade),
}
#[async_trait]
impl<S> FromRequestParts<S> for ProtocolUpgrade
where
S: Send + Sync,
{
type Rejection = NotaryServerError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// Extract tcp connection for websocket client
if header_eq(&parts.headers, header::UPGRADE, "websocket") {
let extractor = WebSocketUpgrade::from_request_parts(parts, state)
.await
.map_err(|err| NotaryServerError::BadProverRequest(err.to_string()))?;
return Ok(Self::Ws(extractor));
// Extract tcp connection for tcp client
} else if header_eq(&parts.headers, header::UPGRADE, "tcp") {
let extractor = TcpUpgrade::from_request_parts(parts, state)
.await
.map_err(|err| NotaryServerError::BadProverRequest(err.to_string()))?;
return Ok(Self::Tcp(extractor));
} else {
return Err(NotaryServerError::BadProverRequest(
"Upgrade header is not set for client".to_string(),
));
}
}
}
/// Handler to upgrade protocol from http to either websocket or underlying tcp depending on the type of client
/// the session_id header is also extracted here to fetch the configuration parameters
/// that have been submitted in the previous request to /session made by the same client
pub async fn upgrade_protocol(
protocol_upgrade: ProtocolUpgrade,
mut headers: HeaderMap,
State(notary_globals): State<NotaryGlobals>,
) -> Response {
info!("Received upgrade protocol request");
// Extract the session_id from the headers
let session_id = match headers.remove("X-Session-Id") {
Some(session_id) => match session_id.to_str() {
Ok(session_id) => session_id.to_string(),
Err(err) => {
let err_msg = format!("X-Session-Id header submitted is not a string: {}", err);
error!(err_msg);
return NotaryServerError::BadProverRequest(err_msg).into_response();
}
},
None => {
let err_msg = "Missing X-Session-Id in upgrade protocol request".to_string();
error!(err_msg);
return NotaryServerError::BadProverRequest(err_msg).into_response();
}
};
// Fetch the configuration data from the store using the session_id
let max_transcript_size = match notary_globals.store.lock().await.get(&session_id) {
Some(max_transcript_size) => max_transcript_size.to_owned(),
None => {
let err_msg = format!("Session id {} does not exist", session_id);
error!(err_msg);
return NotaryServerError::BadProverRequest(err_msg).into_response();
}
};
// This completes the HTTP Upgrade request and returns a successful response to the client, meanwhile initiating the websocket or tcp connection
match protocol_upgrade {
ProtocolUpgrade::Ws(ws) => ws.on_upgrade(move |socket| {
websocket_notarize(socket, notary_globals, session_id, max_transcript_size)
}),
ProtocolUpgrade::Tcp(tcp) => tcp.on_upgrade(move |stream| {
tcp_notarize(stream, notary_globals, session_id, max_transcript_size)
}),
}
}
/// Handler to initialize and configure notarization for both TCP and WebSocket clients
#[debug_handler(state = NotaryGlobals)]
pub async fn initialize(
State(notary_globals): State<NotaryGlobals>,
payload: Result<Json<NotarizationSessionRequest>, JsonRejection>,
) -> impl IntoResponse {
info!(
?payload,
"Received request for initializing a notarization session"
);
// Parse the body payload
let payload = match payload {
Ok(payload) => payload,
Err(err) => {
error!("Malformed payload submitted for initializing notarization: {err}");
return NotaryServerError::BadProverRequest(err.to_string()).into_response();
}
};
// Ensure that the max_transcript_size submitted is not larger than the global max limit configured in notary server
if payload.max_transcript_size > Some(notary_globals.notarization_config.max_transcript_size) {
error!(
"Max transcript size requested {:?} exceeds the maximum threshold {:?}",
payload.max_transcript_size, notary_globals.notarization_config.max_transcript_size
);
return NotaryServerError::BadProverRequest(
"Max transcript size requested exceeds the maximum threshold".to_string(),
)
.into_response();
}
let prover_session_id = Uuid::new_v4().to_string();
// Store the configuration data in a temporary store
notary_globals
.store
.lock()
.await
.insert(prover_session_id.clone(), payload.max_transcript_size);
trace!("Latest store state: {:?}", notary_globals.store);
// Return the session id in the response to the client
(
StatusCode::OK,
Json(NotarizationSessionResponse {
session_id: prover_session_id,
}),
)
.into_response()
}
/// Run the notarization
pub async fn notary_service<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
socket: T,
signing_key: &SigningKey,
session_id: &str,
max_transcript_size: Option<usize>,
) -> Result<(), NotaryServerError> {
debug!(?session_id, "Starting notarization...");
let mut config_builder = NotaryConfig::builder();
config_builder.id(session_id);
if let Some(max_transcript_size) = max_transcript_size {
config_builder.max_transcript_size(max_transcript_size);
}
let config = config_builder.build()?;
let (notary, notary_fut) = bind_notary(config, socket.compat())?;
// Run the notary and background processes concurrently
tokio::try_join!(notary_fut, notary.notarize::<Signature>(signing_key),).map(|_| Ok(()))?
}

View File

@@ -0,0 +1,913 @@
//! The following code is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.6.19/axum/src/extract/ws.rs
//! where we swapped out tokio_tungstenite (https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/)
//! with async_tungstenite (https://docs.rs/async-tungstenite/latest/async_tungstenite/) so that we can use
//! ws_stream_tungstenite (https://docs.rs/ws_stream_tungstenite/latest/ws_stream_tungstenite/index.html)
//! to get AsyncRead and AsyncWrite implemented for the WebSocket. Any other modification is commented with the prefix "NOTARY_MODIFICATION:"
//!
//! The code is under the following license:
//!
//! Copyright (c) 2019 Axum Contributors
//!
//! Permission is hereby granted, free of charge, to any
//! person obtaining a copy of this software and associated
//! documentation files (the "Software"), to deal in the
//! Software without restriction, including without
//! limitation the rights to use, copy, modify, merge,
//! publish, distribute, sublicense, and/or sell copies of
//! the Software, and to permit persons to whom the Software
//! is furnished to do so, subject to the following
//! conditions:
//!
//! The above copyright notice and this permission notice
//! shall be included in all copies or substantial portions
//! of the Software.
//!
//! THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
//! ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
//! TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
//! PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
//! SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
//! CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
//! OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
//! IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
//! DEALINGS IN THE SOFTWARE.
//!
//!
//! Handle WebSocket connections.
//!
//! # Example
//!
//! ```
//! use axum::{
//! extract::ws::{WebSocketUpgrade, WebSocket},
//! routing::get,
//! response::{IntoResponse, Response},
//! Router,
//! };
//!
//! let app = Router::new().route("/ws", get(handler));
//!
//! async fn handler(ws: WebSocketUpgrade) -> Response {
//! ws.on_upgrade(handle_socket)
//! }
//!
//! async fn handle_socket(mut socket: WebSocket) {
//! while let Some(msg) = socket.recv().await {
//! let msg = if let Ok(msg) = msg {
//! msg
//! } else {
//! // client disconnected
//! return;
//! };
//!
//! if socket.send(msg).await.is_err() {
//! // client disconnected
//! return;
//! }
//! }
//! }
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! # Passing data and/or state to an `on_upgrade` callback
//!
//! ```
//! use axum::{
//! extract::{ws::{WebSocketUpgrade, WebSocket}, State},
//! response::Response,
//! routing::get,
//! Router,
//! };
//!
//! #[derive(Clone)]
//! struct AppState {
//! // ...
//! }
//!
//! async fn handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
//! ws.on_upgrade(|socket| handle_socket(socket, state))
//! }
//!
//! async fn handle_socket(socket: WebSocket, state: AppState) {
//! // ...
//! }
//!
//! let app = Router::new()
//! .route("/ws", get(handler))
//! .with_state(AppState { /* ... */ });
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! # Read and write concurrently
//!
//! If you need to read and write concurrently from a [`WebSocket`] you can use
//! [`StreamExt::split`]:
//!
//! ```rust,no_run
//! use axum::{Error, extract::ws::{WebSocket, Message}};
//! use futures_util::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}};
//!
//! async fn handle_socket(mut socket: WebSocket) {
//! let (mut sender, mut receiver) = socket.split();
//!
//! tokio::spawn(write(sender));
//! tokio::spawn(read(receiver));
//! }
//!
//! async fn read(receiver: SplitStream<WebSocket>) {
//! // ...
//! }
//!
//! async fn write(sender: SplitSink<WebSocket, Message>) {
//! // ...
//! }
//! ```
//!
//! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
#![allow(unused)]
use self::rejection::*;
use async_trait::async_trait;
use async_tungstenite::{
tokio::TokioAdapter,
tungstenite::{
self as ts,
protocol::{self, WebSocketConfig},
},
WebSocketStream,
};
use axum::{
body::{self, Bytes},
extract::FromRequestParts,
response::Response,
Error,
};
use futures_util::{
sink::{Sink, SinkExt},
stream::{Stream, StreamExt},
};
use http::{
header::{self, HeaderMap, HeaderName, HeaderValue},
request::Parts,
Method, StatusCode,
};
use hyper::upgrade::{OnUpgrade, Upgraded};
use sha1::{Digest, Sha1};
use std::{
borrow::Cow,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tracing::error;
/// Extractor for establishing WebSocket connections.
///
/// Note: This extractor requires the request method to be `GET` so it should
/// always be used with [`get`](crate::routing::get). Requests with other methods will be
/// rejected.
///
/// See the [module docs](self) for an example.
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpdgrade> {
config: WebSocketConfig,
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
protocol: Option<HeaderValue>,
sec_websocket_key: HeaderValue,
on_upgrade: OnUpgrade,
on_failed_upgrade: F,
sec_websocket_protocol: Option<HeaderValue>,
}
impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("WebSocketUpgrade")
.field("config", &self.config)
.field("protocol", &self.protocol)
.field("sec_websocket_key", &self.sec_websocket_key)
.field("sec_websocket_protocol", &self.sec_websocket_protocol)
.finish_non_exhaustive()
}
}
impl<F> WebSocketUpgrade<F> {
/// Set the size of the internal message send queue.
pub fn max_send_queue(mut self, max: usize) -> Self {
self.config.max_send_queue = Some(max);
self
}
/// Set the maximum message size (defaults to 64 megabytes)
pub fn max_message_size(mut self, max: usize) -> Self {
self.config.max_message_size = Some(max);
self
}
/// Set the maximum frame size (defaults to 16 megabytes)
pub fn max_frame_size(mut self, max: usize) -> Self {
self.config.max_frame_size = Some(max);
self
}
/// Allow server to accept unmasked frames (defaults to false)
pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
self.config.accept_unmasked_frames = accept;
self
}
/// Set the known protocols.
///
/// If the protocol name specified by `Sec-WebSocket-Protocol` header
/// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
/// return the protocol name.
///
/// The protocols should be listed in decreasing order of preference: if the client offers
/// multiple protocols that the server could support, the server will pick the first one in
/// this list.
///
/// # Examples
///
/// ```
/// use axum::{
/// extract::ws::{WebSocketUpgrade, WebSocket},
/// routing::get,
/// response::{IntoResponse, Response},
/// Router,
/// };
///
/// let app = Router::new().route("/ws", get(handler));
///
/// async fn handler(ws: WebSocketUpgrade) -> Response {
/// ws.protocols(["graphql-ws", "graphql-transport-ws"])
/// .on_upgrade(|socket| async {
/// // ...
/// })
/// }
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
pub fn protocols<I>(mut self, protocols: I) -> Self
where
I: IntoIterator,
I::Item: Into<Cow<'static, str>>,
{
if let Some(req_protocols) = self
.sec_websocket_protocol
.as_ref()
.and_then(|p| p.to_str().ok())
{
self.protocol = protocols
.into_iter()
// FIXME: This will often allocate a new `String` and so is less efficient than it
// could be. But that can't be fixed without breaking changes to the public API.
.map(Into::into)
.find(|protocol| {
req_protocols
.split(',')
.any(|req_protocol| req_protocol.trim() == protocol)
})
.map(|protocol| match protocol {
Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
Cow::Borrowed(s) => HeaderValue::from_static(s),
});
}
self
}
/// Provide a callback to call if upgrading the connection fails.
///
/// The connection upgrade is performed in a background task. If that fails this callback
/// will be called.
///
/// By default any errors will be silently ignored.
///
/// # Example
///
/// ```
/// use axum::{
/// extract::{WebSocketUpgrade},
/// response::Response,
/// };
///
/// async fn handler(ws: WebSocketUpgrade) -> Response {
/// ws.on_failed_upgrade(|error| {
/// report_error(error);
/// })
/// .on_upgrade(|socket| async { /* ... */ })
/// }
/// #
/// # fn report_error(_: axum::Error) {}
/// ```
pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
where
C: OnFailedUpdgrade,
{
WebSocketUpgrade {
config: self.config,
protocol: self.protocol,
sec_websocket_key: self.sec_websocket_key,
on_upgrade: self.on_upgrade,
on_failed_upgrade: callback,
sec_websocket_protocol: self.sec_websocket_protocol,
}
}
/// Finalize upgrading the connection and call the provided callback with
/// the stream.
#[must_use = "to setup the WebSocket connection, this response must be returned"]
pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
where
C: FnOnce(WebSocket) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
F: OnFailedUpdgrade,
{
let on_upgrade = self.on_upgrade;
let config = self.config;
let on_failed_upgrade = self.on_failed_upgrade;
let protocol = self.protocol.clone();
tokio::spawn(async move {
let upgraded = match on_upgrade.await {
Ok(upgraded) => upgraded,
Err(err) => {
error!("Something wrong with on_upgrade: {:?}", err);
on_failed_upgrade.call(Error::new(err));
return;
}
};
let socket = WebSocketStream::from_raw_socket(
// NOTARY_MODIFICATION: Need to use TokioAdapter to wrap Upgraded which doesn't implement futures crate's AsyncRead and AsyncWrite
TokioAdapter::new(upgraded),
protocol::Role::Server,
Some(config),
)
.await;
let socket = WebSocket {
inner: socket,
protocol,
};
callback(socket).await;
});
#[allow(clippy::declare_interior_mutable_const)]
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
#[allow(clippy::declare_interior_mutable_const)]
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, UPGRADE)
.header(header::UPGRADE, WEBSOCKET)
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(self.sec_websocket_key.as_bytes()),
);
if let Some(protocol) = self.protocol {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
builder.body(body::boxed(body::Empty::new())).unwrap()
}
}
/// What to do when a connection upgrade fails.
///
/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
pub trait OnFailedUpdgrade: Send + 'static {
/// Call the callback.
fn call(self, error: Error);
}
impl<F> OnFailedUpdgrade for F
where
F: FnOnce(Error) + Send + 'static,
{
fn call(self, error: Error) {
self(error)
}
}
/// The default `OnFailedUpdgrade` used by `WebSocketUpgrade`.
///
/// It simply ignores the error.
#[non_exhaustive]
#[derive(Debug)]
pub struct DefaultOnFailedUpdgrade;
impl OnFailedUpdgrade for DefaultOnFailedUpdgrade {
#[inline]
fn call(self, _error: Error) {}
}
#[async_trait]
impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpdgrade>
where
S: Send + Sync,
{
type Rejection = WebSocketUpgradeRejection;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if parts.method != Method::GET {
return Err(MethodNotGet.into());
}
if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
return Err(InvalidConnectionHeader.into());
}
if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
return Err(InvalidUpgradeHeader.into());
}
if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
return Err(InvalidWebSocketVersionHeader.into());
}
let sec_websocket_key = parts
.headers
.get(header::SEC_WEBSOCKET_KEY)
.ok_or(WebSocketKeyHeaderMissing)?
.clone();
let on_upgrade = parts
.extensions
.remove::<OnUpgrade>()
.ok_or(ConnectionNotUpgradable)?;
let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
Ok(Self {
config: Default::default(),
protocol: None,
sec_websocket_key,
on_upgrade,
sec_websocket_protocol,
on_failed_upgrade: DefaultOnFailedUpdgrade,
})
}
}
pub fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
if let Some(header) = headers.get(&key) {
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
} else {
false
}
}
fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
let header = if let Some(header) = headers.get(&key) {
header
} else {
return false;
};
if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
header.to_ascii_lowercase().contains(value)
} else {
false
}
}
/// A stream of WebSocket messages.
///
/// See [the module level documentation](self) for more details.
#[derive(Debug)]
pub struct WebSocket {
inner: WebSocketStream<TokioAdapter<Upgraded>>,
protocol: Option<HeaderValue>,
}
impl WebSocket {
/// Consume `self` and get the inner [`async_tungstenite::WebSocketStream`].
pub fn into_inner(self) -> WebSocketStream<TokioAdapter<Upgraded>> {
self.inner
}
/// Receive another message.
///
/// Returns `None` if the stream has closed.
pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
self.next().await
}
/// Send a message.
pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
self.inner
.send(msg.into_tungstenite())
.await
.map_err(Error::new)
}
/// Gracefully close this WebSocket.
pub async fn close(mut self) -> Result<(), Error> {
self.inner.close(None).await.map_err(Error::new)
}
/// Return the selected WebSocket subprotocol, if one has been chosen.
pub fn protocol(&self) -> Option<&HeaderValue> {
self.protocol.as_ref()
}
}
impl Stream for WebSocket {
type Item = Result<Message, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
Some(Ok(msg)) => {
if let Some(msg) = Message::from_tungstenite(msg) {
return Poll::Ready(Some(Ok(msg)));
}
}
Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))),
None => return Poll::Ready(None),
}
}
}
}
impl Sink<Message> for WebSocket {
type Error = Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new)
}
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
Pin::new(&mut self.inner)
.start_send(item.into_tungstenite())
.map_err(Error::new)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new)
}
}
/// Status code used to indicate why an endpoint is closing the WebSocket connection.
pub type CloseCode = u16;
/// A struct representing the close command.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame<'t> {
/// The reason as a code.
pub code: CloseCode,
/// The reason as text string.
pub reason: Cow<'t, str>,
}
/// A WebSocket message.
//
// This code comes from https://github.com/snapview/tungstenite-rs/blob/master/src/protocol/message.rs and is under following license:
// Copyright (c) 2017 Alexey Galakhov
// Copyright (c) 2016 Jason Housley
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message {
/// A text WebSocket message
Text(String),
/// A binary WebSocket message
Binary(Vec<u8>),
/// A ping message with the specified payload
///
/// The payload here must have a length less than 125 bytes.
///
/// Ping messages will be automatically responded to by the server, so you do not have to worry
/// about dealing with them yourself.
Ping(Vec<u8>),
/// A pong message with the specified payload
///
/// The payload here must have a length less than 125 bytes.
///
/// Pong messages will be automatically sent to the client if a ping message is received, so
/// you do not have to worry about constructing them yourself unless you want to implement a
/// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
Pong(Vec<u8>),
/// A close message with the optional close frame.
Close(Option<CloseFrame<'static>>),
}
impl Message {
fn into_tungstenite(self) -> ts::Message {
match self {
Self::Text(text) => ts::Message::Text(text),
Self::Binary(binary) => ts::Message::Binary(binary),
Self::Ping(ping) => ts::Message::Ping(ping),
Self::Pong(pong) => ts::Message::Pong(pong),
Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
code: ts::protocol::frame::coding::CloseCode::from(close.code),
reason: close.reason,
})),
Self::Close(None) => ts::Message::Close(None),
}
}
fn from_tungstenite(message: ts::Message) -> Option<Self> {
match message {
ts::Message::Text(text) => Some(Self::Text(text)),
ts::Message::Binary(binary) => Some(Self::Binary(binary)),
ts::Message::Ping(ping) => Some(Self::Ping(ping)),
ts::Message::Pong(pong) => Some(Self::Pong(pong)),
ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
code: close.code.into(),
reason: close.reason,
}))),
ts::Message::Close(None) => Some(Self::Close(None)),
// we can ignore `Frame` frames as recommended by the tungstenite maintainers
// https://github.com/snapview/tungstenite-rs/issues/268
ts::Message::Frame(_) => None,
}
}
/// Consume the WebSocket and return it as binary data.
pub fn into_data(self) -> Vec<u8> {
match self {
Self::Text(string) => string.into_bytes(),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
Self::Close(None) => Vec::new(),
Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
}
}
/// Attempt to consume the WebSocket message and convert it to a String.
pub fn into_text(self) -> Result<String, Error> {
match self {
Self::Text(string) => Ok(string),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data)
.map_err(|err| err.utf8_error())
.map_err(Error::new)?),
Self::Close(None) => Ok(String::new()),
Self::Close(Some(frame)) => Ok(frame.reason.into_owned()),
}
}
/// Attempt to get a &str from the WebSocket message,
/// this will try to convert binary data to utf8.
pub fn to_text(&self) -> Result<&str, Error> {
match *self {
Self::Text(ref string) => Ok(string),
Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
Ok(std::str::from_utf8(data).map_err(Error::new)?)
}
Self::Close(None) => Ok(""),
Self::Close(Some(ref frame)) => Ok(&frame.reason),
}
}
}
impl From<String> for Message {
fn from(string: String) -> Self {
Message::Text(string)
}
}
impl<'s> From<&'s str> for Message {
fn from(string: &'s str) -> Self {
Message::Text(string.into())
}
}
impl<'b> From<&'b [u8]> for Message {
fn from(data: &'b [u8]) -> Self {
Message::Binary(data.into())
}
}
impl From<Vec<u8>> for Message {
fn from(data: Vec<u8>) -> Self {
Message::Binary(data)
}
}
impl From<Message> for Vec<u8> {
fn from(msg: Message) -> Self {
msg.into_data()
}
}
fn sign(key: &[u8]) -> HeaderValue {
use base64::engine::Engine as _;
let mut sha1 = Sha1::default();
sha1.update(key);
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
}
pub mod rejection {
//! WebSocket specific rejections.
use axum_core::__composite_rejection as composite_rejection;
use axum_core::__define_rejection as define_rejection;
define_rejection! {
#[status = METHOD_NOT_ALLOWED]
#[body = "Request method must be `GET`"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct MethodNotGet;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Connection header did not include 'upgrade'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidConnectionHeader;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Upgrade` header did not include 'websocket'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidUpgradeHeader;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Sec-WebSocket-Version` header did not include '13'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidWebSocketVersionHeader;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Sec-WebSocket-Key` header missing"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct WebSocketKeyHeaderMissing;
}
define_rejection! {
#[status = UPGRADE_REQUIRED]
#[body = "WebSocket request couldn't be upgraded since no upgrade state was present"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
///
/// This rejection is returned if the connection cannot be upgraded for example if the
/// request is HTTP/1.0.
///
/// See [MDN] for more details about connection upgrades.
///
/// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Upgrade
pub struct ConnectionNotUpgradable;
}
composite_rejection! {
/// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
///
/// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
/// extractor can fail.
pub enum WebSocketUpgradeRejection {
MethodNotGet,
InvalidConnectionHeader,
InvalidUpgradeHeader,
InvalidWebSocketVersionHeader,
WebSocketKeyHeaderMissing,
ConnectionNotUpgradable,
}
}
}
pub mod close_code {
//! Constants for [`CloseCode`]s.
//!
//! [`CloseCode`]: super::CloseCode
/// Indicates a normal closure, meaning that the purpose for which the connection was
/// established has been fulfilled.
pub const NORMAL: u16 = 1000;
/// Indicates that an endpoint is "going away", such as a server going down or a browser having
/// navigated away from a page.
pub const AWAY: u16 = 1001;
/// Indicates that an endpoint is terminating the connection due to a protocol error.
pub const PROTOCOL: u16 = 1002;
/// Indicates that an endpoint is terminating the connection because it has received a type of
/// data it cannot accept (e.g., an endpoint that understands only text data MAY send this if
/// it receives a binary message).
pub const UNSUPPORTED: u16 = 1003;
/// Indicates that no status code was included in a closing frame.
pub const STATUS: u16 = 1005;
/// Indicates an abnormal closure.
pub const ABNORMAL: u16 = 1006;
/// Indicates that an endpoint is terminating the connection because it has received data
/// within a message that was not consistent with the type of the message (e.g., non-UTF-8
/// RFC3629 data within a text message).
pub const INVALID: u16 = 1007;
/// Indicates that an endpoint is terminating the connection because it has received a message
/// that violates its policy. This is a generic status code that can be returned when there is
/// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a need to
/// hide specific details about the policy.
pub const POLICY: u16 = 1008;
/// Indicates that an endpoint is terminating the connection because it has received a message
/// that is too big for it to process.
pub const SIZE: u16 = 1009;
/// Indicates that an endpoint (client) is terminating the connection because it has expected
/// the server to negotiate one or more extension, but the server didn't return them in the
/// response message of the WebSocket handshake. The list of extensions that are needed should
/// be given as the reason for closing. Note that this status code is not used by the server,
/// because it can fail the WebSocket handshake instead.
pub const EXTENSION: u16 = 1010;
/// Indicates that a server is terminating the connection because it encountered an unexpected
/// condition that prevented it from fulfilling the request.
pub const ERROR: u16 = 1011;
/// Indicates that the server is restarting.
pub const RESTART: u16 = 1012;
/// Indicates that the server is overloaded and the client should either connect to a different
/// IP (when multiple targets exist), or reconnect to the same IP when a user has performed an
/// action.
pub const AGAIN: u16 = 1013;
}
#[cfg(test)]
mod tests {
use super::*;
use axum::{body::Body, routing::get, Router};
use http::{Request, Version};
use tower::ServiceExt;
#[tokio::test]
async fn rejects_http_1_0_requests() {
let svc = get(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
let rejection = ws.unwrap_err();
assert!(matches!(
rejection,
WebSocketUpgradeRejection::ConnectionNotUpgradable(_)
));
std::future::ready(())
});
let req = Request::builder()
.version(Version::HTTP_10)
.method(Method::GET)
.header("upgrade", "websocket")
.header("connection", "Upgrade")
.header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==")
.header("sec-websocket-version", "13")
.body(Body::empty())
.unwrap();
let res = svc.oneshot(req).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
#[allow(dead_code)]
fn default_on_failed_upgrade() {
async fn handler(ws: WebSocketUpgrade) -> Response {
ws.on_upgrade(|_| async {})
}
let _: Router = Router::new().route("/", get(handler));
}
#[allow(dead_code)]
fn on_failed_upgrade() {
async fn handler(ws: WebSocketUpgrade) -> Response {
ws.on_failed_upgrade(|_error: Error| println!("oops!"))
.on_upgrade(|_| async {})
}
let _: Router = Router::new().route("/", get(handler));
}
}

101
src/service/tcp.rs Normal file
View File

@@ -0,0 +1,101 @@
use async_trait::async_trait;
use axum::{
body,
extract::FromRequestParts,
http::{header, request::Parts, HeaderValue, StatusCode},
response::Response,
};
use hyper::upgrade::{OnUpgrade, Upgraded};
use std::future::Future;
use tracing::{debug, error, info};
use crate::{domain::notary::NotaryGlobals, service::notary_service, NotaryServerError};
/// Custom extractor used to extract underlying TCP connection for TCP client — using the same upgrade primitives used by
/// the WebSocket implementation where the underlying TCP connection (wrapped in an Upgraded object) only gets polled as an OnUpgrade future
/// after the ongoing HTTP request is finished (ref: https://github.com/tokio-rs/axum/blob/a6a849bb5b96a2f641fa077fe76f70ad4d20341c/axum/src/extract/ws.rs#L122)
///
/// More info on the upgrade primitives: https://docs.rs/hyper/latest/hyper/upgrade/index.html
pub struct TcpUpgrade {
pub on_upgrade: OnUpgrade,
}
#[async_trait]
impl<S> FromRequestParts<S> for TcpUpgrade
where
S: Send + Sync,
{
type Rejection = NotaryServerError;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let on_upgrade =
parts
.extensions
.remove::<OnUpgrade>()
.ok_or(NotaryServerError::BadProverRequest(
"Upgrade header is not set for TCP client".to_string(),
))?;
Ok(Self { on_upgrade })
}
}
impl TcpUpgrade {
/// Utility function to complete the http upgrade protocol by
/// (1) Return 101 switching protocol response to client to indicate the switching to TCP
/// (2) Spawn a new thread to await on the OnUpgrade object to claim the underlying TCP connection
pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
where
C: FnOnce(Upgraded) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let on_upgrade = self.on_upgrade;
tokio::spawn(async move {
let upgraded = match on_upgrade.await {
Ok(upgraded) => upgraded,
Err(err) => {
error!("Something wrong with upgrading HTTP: {:?}", err);
return;
}
};
callback(upgraded).await;
});
#[allow(clippy::declare_interior_mutable_const)]
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
#[allow(clippy::declare_interior_mutable_const)]
const TCP: HeaderValue = HeaderValue::from_static("tcp");
let builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, UPGRADE)
.header(header::UPGRADE, TCP);
builder.body(body::boxed(body::Empty::new())).unwrap()
}
}
/// Perform notarization using the extracted tcp connection
pub async fn tcp_notarize(
stream: Upgraded,
notary_globals: NotaryGlobals,
session_id: String,
max_transcript_size: Option<usize>,
) {
debug!(?session_id, "Upgraded to tcp connection");
match notary_service(
stream,
&notary_globals.notary_signing_key,
&session_id,
max_transcript_size,
)
.await
{
Ok(_) => {
info!(?session_id, "Successful notarization using tcp!");
}
Err(err) => {
error!(?session_id, "Failed notarization using tcp: {err}");
}
}
}

34
src/service/websocket.rs Normal file
View File

@@ -0,0 +1,34 @@
use tracing::{debug, error, info};
use ws_stream_tungstenite::WsStream;
use crate::{
domain::notary::NotaryGlobals,
service::{axum_websocket::WebSocket, notary_service},
};
/// Perform notarization using the established websocket connection
pub async fn websocket_notarize(
socket: WebSocket,
notary_globals: NotaryGlobals,
session_id: String,
max_transcript_size: Option<usize>,
) {
debug!(?session_id, "Upgraded to websocket connection");
// Wrap the websocket in WsStream so that we have AsyncRead and AsyncWrite implemented
let stream = WsStream::new(socket.into_inner());
match notary_service(
stream,
&notary_globals.notary_signing_key,
&session_id,
max_transcript_size,
)
.await
{
Ok(_) => {
info!(?session_id, "Successful notarization using websocket!");
}
Err(err) => {
error!(?session_id, "Failed notarization using websocket: {err}");
}
}
}

View File

@@ -17,7 +17,7 @@ mod test {
#[test]
fn test_parse_config_file() {
let location = "./src/config/config.yaml";
let location = "./config.yaml";
let config: Result<NotaryServerProperties> = parse_config_file(location);
assert!(
config.is_ok(),

View File

@@ -1,5 +1,13 @@
use async_tungstenite::{
tokio::connect_async_with_tls_connector_and_config, tungstenite::protocol::WebSocketConfig,
};
use futures::AsyncWriteExt;
use hyper::{body::to_bytes, Body, Request, StatusCode};
use hyper::{
body::to_bytes,
client::{conn::Parts, HttpConnector},
Body, Client, Request, StatusCode,
};
use hyper_tls::HttpsConnector;
use rustls::{Certificate, ClientConfig, RootCertStore};
use std::{
net::{IpAddr, SocketAddr},
@@ -11,21 +19,26 @@ use tlsn_prover::{bind_prover, ProverConfig};
use tokio_rustls::TlsConnector;
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tracing::debug;
use ws_stream_tungstenite::WsStream;
use notary_server::{
read_pem_file, run_tcp_server, NotaryServerProperties, NotarySignatureProperties,
read_pem_file, run_server, NotarizationProperties, NotarizationSessionRequest,
NotarizationSessionResponse, NotaryServerProperties, NotarySignatureProperties,
ServerProperties, TLSSignatureProperties, TracingProperties,
};
const NOTARY_CA_CERT_PATH: &str = "./src/fixture/tls/rootCA.crt";
const NOTARY_CA_CERT_BYTES: &[u8] = include_bytes!("../src/fixture/tls/rootCA.crt");
#[tokio::test]
async fn test_notarization() {
async fn setup_config_and_server(sleep_ms: u64, port: u16) -> NotaryServerProperties {
let notary_config = NotaryServerProperties {
server: ServerProperties {
name: "tlsnotaryserver.io".to_string(),
domain: "127.0.0.1".to_string(),
port: 7047,
port,
},
notarization: NotarizationProperties {
max_transcript_size: 1 << 14,
},
tls_signature: TLSSignatureProperties {
private_key_pem_path: "./src/fixture/tls/notary.key".to_string(),
@@ -39,24 +52,26 @@ async fn test_notarization() {
},
};
tracing_subscriber::fmt::init();
let _ = tracing_subscriber::fmt::try_init();
let config = notary_config.clone();
// Run the the notary server
// Run the notary server
tokio::spawn(async move {
run_tcp_server(&config).await.unwrap();
run_server(&config).await.unwrap();
});
// Sleep for a while to allow notary server to finish set up and start listening
tokio::time::sleep(Duration::from_millis(100)).await;
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
// Run the prover
run_prover(&notary_config).await;
notary_config
}
#[tracing::instrument(skip(notary_config))]
async fn run_prover(notary_config: &NotaryServerProperties) {
#[tokio::test]
async fn test_tcp_prover() {
// Notary server configuration setup
let notary_config = setup_config_and_server(100, 7048).await;
// Connect to the Notary via TLS-TCP
let mut certificate_file_reader = read_pem_file(NOTARY_CA_CERT_PATH).await.unwrap();
let mut certificates: Vec<Certificate> = rustls_pemfile::certs(&mut certificate_file_reader)
@@ -75,14 +90,15 @@ async fn run_prover(notary_config: &NotaryServerProperties) {
.with_no_client_auth();
let notary_connector = TlsConnector::from(Arc::new(client_notary_config));
let notary_domain = notary_config.server.domain.clone();
let notary_port = notary_config.server.port;
let notary_socket = tokio::net::TcpStream::connect(SocketAddr::new(
IpAddr::V4(notary_config.server.domain.parse().unwrap()),
notary_config.server.port,
IpAddr::V4(notary_domain.parse().unwrap()),
notary_port,
))
.await
.unwrap();
let prover_address = notary_socket.local_addr().unwrap().to_string();
let notary_tls_socket = notary_connector
.connect(
notary_config.server.name.as_str().try_into().unwrap(),
@@ -91,6 +107,77 @@ async fn run_prover(notary_config: &NotaryServerProperties) {
.await
.unwrap();
// Attach the hyper HTTP client to the notary TLS connection to send request to the /session endpoint to configure notarization and obtain session id
let (mut request_sender, connection) = hyper::client::conn::handshake(notary_tls_socket)
.await
.unwrap();
// Spawn the HTTP task to be run concurrently
let connection_task = tokio::spawn(connection.without_shutdown());
// Build the HTTP request to configure notarization
let payload = serde_json::to_string(&NotarizationSessionRequest {
client_type: notary_server::ClientType::Tcp,
max_transcript_size: Some(notary_config.notarization.max_transcript_size),
})
.unwrap();
let request = Request::builder()
.uri(format!("https://{notary_domain}:{notary_port}/session"))
.method("POST")
.header("Host", notary_domain.clone())
// Need to specify application/json for axum to parse it as json
.header("Content-Type", "application/json")
.body(Body::from(payload))
.unwrap();
debug!("Sending configuration request");
let response = request_sender.send_request(request).await.unwrap();
debug!("Sent configuration request");
assert!(response.status() == StatusCode::OK);
debug!("Response OK");
// Pretty printing :)
let payload = to_bytes(response.into_body()).await.unwrap().to_vec();
let notarization_response =
serde_json::from_str::<NotarizationSessionResponse>(&String::from_utf8_lossy(&payload))
.unwrap();
debug!("Notarization response: {:?}", notarization_response,);
// Send notarization request via HTTP, where the underlying TCP connection will be extracted later
let request = Request::builder()
.uri(format!("https://{notary_domain}:{notary_port}/notarize"))
.method("GET")
.header("Host", notary_domain)
.header("Connection", "Upgrade")
// Need to specify this upgrade header for server to extract tcp connection later
.header("Upgrade", "TCP")
// Need to specify the session_id so that notary server knows the right configuration to use
// as the configuration is set in the previous HTTP call
.header("X-Session-Id", notarization_response.session_id.clone())
.body(Body::empty())
.unwrap();
debug!("Sending notarization request");
let response = request_sender.send_request(request).await.unwrap();
debug!("Sent notarization request");
assert!(response.status() == StatusCode::SWITCHING_PROTOCOLS);
debug!("Switched protocol OK");
// Claim back the TCP socket after HTTP exchange is done so that client can use it for notarization
let Parts {
io: notary_tls_socket,
..
} = connection_task.await.unwrap().unwrap();
// Connect to the Server
let (client_socket, server_socket) = tokio::io::duplex(2 << 16);
let server_task = tokio::spawn(bind_test_server(server_socket.compat()));
@@ -100,9 +187,9 @@ async fn run_prover(notary_config: &NotaryServerProperties) {
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
// Basic default prover config — use local address as the notarization session id
// Basic default prover config — use the responded session id from notary server
let prover_config = ProverConfig::builder()
.id(prover_address)
.id(notarization_response.session_id)
.server_dns(SERVER_DOMAIN)
.root_cert_store(root_store)
.build()
@@ -167,3 +254,169 @@ async fn run_prover(notary_config: &NotaryServerProperties) {
debug!("Done notarization!");
}
#[tokio::test]
async fn test_websocket_prover() {
// Notary server configuration setup
let notary_config = setup_config_and_server(100, 7049).await;
let notary_domain = notary_config.server.domain.clone();
let notary_port = notary_config.server.port;
// Connect to the notary server via TLS-WebSocket
// Try to avoid dealing with transport layer directly to mimic the limitation of a browser extension that uses websocket
//
// Establish TLS setup for connections later
let certificate =
tokio_native_tls::native_tls::Certificate::from_pem(NOTARY_CA_CERT_BYTES).unwrap();
let notary_tls_connector = tokio_native_tls::native_tls::TlsConnector::builder()
.add_root_certificate(certificate)
.use_sni(false)
.danger_accept_invalid_certs(true)
.build()
.unwrap();
// Call the /session HTTP API to configure notarization and obtain session id
let mut hyper_http_connector = HttpConnector::new();
hyper_http_connector.enforce_http(false);
let mut hyper_tls_connector =
HttpsConnector::from((hyper_http_connector, notary_tls_connector.clone().into()));
hyper_tls_connector.https_only(true);
let https_client = Client::builder().build::<_, hyper::Body>(hyper_tls_connector);
// Build the HTTP request to configure notarization
let payload = serde_json::to_string(&NotarizationSessionRequest {
client_type: notary_server::ClientType::Websocket,
max_transcript_size: Some(notary_config.notarization.max_transcript_size),
})
.unwrap();
let request = Request::builder()
.uri(format!("https://{notary_domain}:{notary_port}/session"))
.method("POST")
.header("Host", notary_domain.clone())
// Need to specify application/json for axum to parse it as json
.header("Content-Type", "application/json")
.body(Body::from(payload))
.unwrap();
debug!("Sending request");
let response = https_client.request(request).await.unwrap();
debug!("Sent request");
assert!(response.status() == StatusCode::OK);
debug!("Response OK");
// Pretty printing :)
let payload = to_bytes(response.into_body()).await.unwrap().to_vec();
let notarization_response =
serde_json::from_str::<NotarizationSessionResponse>(&String::from_utf8_lossy(&payload))
.unwrap();
debug!("Notarization response: {:?}", notarization_response,);
// Connect to the Notary via TLS-Websocket
//
// Note: This will establish a new TLS-TCP connection instead of reusing the previous TCP connection
// used in the previous HTTP POST request because we cannot claim back the tcp connection used in hyper
// client while using its high level request function — there does not seem to have a crate that can let you
// make a request without establishing TCP connection where you can claim the TCP connection later after making the request
let request = http::Request::builder()
.uri(format!("wss://{notary_domain}:{notary_port}/notarize"))
.header("Host", notary_domain.clone())
.header("Sec-WebSocket-Key", uuid::Uuid::new_v4().to_string())
.header("Sec-WebSocket-Version", "13")
.header("Connection", "Upgrade")
.header("Upgrade", "Websocket")
// Need to specify the session_id so that notary server knows the right configuration to use
// as the configuration is set in the previous HTTP call
.header("X-Session-Id", notarization_response.session_id.clone())
.body(())
.unwrap();
let (notary_ws_stream, _) = connect_async_with_tls_connector_and_config(
request,
Some(notary_tls_connector.into()),
Some(WebSocketConfig::default()),
)
.await
.unwrap();
// Wrap the socket with the adapter so that we get AsyncRead and AsyncWrite implemented
let notary_ws_socket = WsStream::new(notary_ws_stream);
// Connect to the Server
let (client_socket, server_socket) = tokio::io::duplex(2 << 16);
let server_task = tokio::spawn(bind_test_server(server_socket.compat()));
let mut root_store = tls_core::anchors::RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
// Basic default prover config — use the responded session id from notary server
let prover_config = ProverConfig::builder()
.id(notarization_response.session_id)
.server_dns(SERVER_DOMAIN)
.root_cert_store(root_store)
.build()
.unwrap();
// Bind the Prover to the sockets
let (tls_connection, prover_fut, mux_fut) =
bind_prover(prover_config, client_socket.compat(), notary_ws_socket)
.await
.unwrap();
// Spawn the Prover and Mux tasks to be run concurrently
tokio::spawn(mux_fut);
let prover_task = tokio::spawn(prover_fut);
let (mut request_sender, connection) = hyper::client::conn::handshake(tls_connection.compat())
.await
.unwrap();
let connection_task = tokio::spawn(connection.without_shutdown());
let request = Request::builder()
.uri(format!("https://{}/echo", SERVER_DOMAIN))
.header("Host", SERVER_DOMAIN)
.header("Connection", "close")
.method("POST")
.body(Body::from("echo"))
.unwrap();
debug!("Sending request to server: {:?}", request);
let response = request_sender.send_request(request).await.unwrap();
assert!(response.status() == StatusCode::OK);
debug!(
"Received response from server: {:?}",
String::from_utf8_lossy(&to_bytes(response.into_body()).await.unwrap())
);
let mut server_tls_conn = server_task.await.unwrap().unwrap();
// Make sure the server closes cleanly (sends close notify)
server_tls_conn.close().await.unwrap();
let mut client_socket = connection_task.await.unwrap().unwrap().io.into_inner();
client_socket.close().await.unwrap();
let mut prover = prover_task.await.unwrap().unwrap();
let sent_len = prover.sent_transcript().data().len();
let recv_len = prover.recv_transcript().data().len();
prover.add_commitment_sent(0..sent_len as u32).unwrap();
prover.add_commitment_recv(0..recv_len as u32).unwrap();
_ = prover.finalize().await.unwrap();
debug!("Done notarization!");
}