mirror of
https://github.com/tlsnotary/notary-server.git
synced 2026-01-10 06:57:59 -05:00
Ns 2 http websocket (#6)
* Setup repo. * Correct typo. * Add tcp listener with tls. * Hook up notary service. * Fix notary signing key loading. * Add otel tracing. * Add test to test prover and notary integration. * Fix notarization test, add comments. * Fix github action. * Fix span logging, github actions. * Add logic to promote to http and then downgrade to tcp for notarization. * Fix client hang issue * Change channel message type. * Fix response parsing from notary. * Fix websocket implementation and use upgrade protocol for raw tcp. * Modify test to mimick browser extension for websocket test. * Refactor tcp client handling. * Add global store for persistent data. * Finish websocket handler and test. * Add comments. * Add more comments and documentation. * Add openapi.yaml. * Modify README. * Add architecture explanation. * Modify README. * Fix PR based on comments. * Combine tcp and websocket extractors. * Refactor and fix documentations.
This commit is contained in:
committed by
GitHub
parent
0e9fadce01
commit
3fcd517c7f
21
Cargo.toml
21
Cargo.toml
@@ -6,10 +6,17 @@ edition = "2021"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.67"
|
||||
async-tungstenite = { version = "0.22.2", features = ["tokio-runtime", "tokio-native-tls"] }
|
||||
axum = { version = "0.6.18", features = ["ws"]}
|
||||
axum-core = "0.3.4"
|
||||
axum-macros = "0.3.8"
|
||||
base64 = "0.21.0"
|
||||
eyre = "0.6.8"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3.28"
|
||||
hyper = { version = "0.14", features = ["client", "http1"] }
|
||||
http = "0.2.9"
|
||||
hyper = { version = "0.14", features = ["client", "http1", "server", "tcp"] }
|
||||
opentelemetry = { version = "0.19" }
|
||||
p256 = "0.13"
|
||||
rustls = { version = "0.21" }
|
||||
@@ -17,15 +24,23 @@ rustls-pemfile = { version = "1.0.2" }
|
||||
serde = { version = "1.0.147", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
serde_yaml = "0.9.21"
|
||||
sha1 = "0.10"
|
||||
structopt = "0.3.26"
|
||||
thiserror = "1"
|
||||
tls-server-fixture = { git = "https://github.com/tlsnotary/tlsn" }
|
||||
tlsn-notary = { git = "https://github.com/tlsnotary/tlsn" }
|
||||
tlsn-prover = { git = "https://github.com/tlsnotary/tlsn" }
|
||||
tlsn-tls-core = { git = "https://github.com/tlsnotary/tlsn" }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tokio-rustls = { version = "0.24.1" }
|
||||
tokio-util = { version = "0.7", features = ["compat"] }
|
||||
tower = { version = "0.4.12", features = ["make"] }
|
||||
tracing = "0.1"
|
||||
tracing-opentelemetry = "0.19"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
uuid = { version = "1.4.1", features = ["v4", "fast-rng"] }
|
||||
ws_stream_tungstenite = { version = "0.10.0", features = ["tokio_io"] }
|
||||
|
||||
[dev-dependencies]
|
||||
hyper-tls = "0.5.0"
|
||||
tls-server-fixture = { git = "https://github.com/tlsnotary/tlsn" }
|
||||
tlsn-prover = { git = "https://github.com/tlsnotary/tlsn" }
|
||||
tokio-native-tls = "0.3.1"
|
||||
|
||||
50
README.md
50
README.md
@@ -8,3 +8,53 @@ An implementation of the notary server in Rust.
|
||||
## ⚠️ Notice
|
||||
|
||||
This project is currently under active development and should not be used in production. Expect bugs and regular major breaking changes.
|
||||
|
||||
---
|
||||
## Running the server
|
||||
1. Configure the server setting in this [file](./config.yaml) — refer [here](./src/config.rs) for more information on the definition of the setting parameters.
|
||||
2. Start the server by running following in a terminal at the top level of this project.
|
||||
```bash
|
||||
cargo run
|
||||
```
|
||||
3. To use a config file from a different location, run the following command to override the default config file location.
|
||||
```bash
|
||||
cargo run -- --config-file <path-to-new-config-file>
|
||||
```
|
||||
|
||||
---
|
||||
## API
|
||||
All APIs are TLS-protected, hence please use `https://` or `wss://`.
|
||||
### HTTP APIs
|
||||
Defined in the [OpenAPI specification](./openapi.yaml).
|
||||
|
||||
### WebSocket APIs
|
||||
#### /notarize
|
||||
##### Description
|
||||
To perform notarization using the session id (unique id returned upon calling the `/session` endpoint successfully) submitted as a custom header.
|
||||
|
||||
##### Custom Header
|
||||
`X-Session-Id`
|
||||
|
||||
##### Custom Header Type
|
||||
String
|
||||
|
||||
---
|
||||
## Architecture
|
||||
### Objective
|
||||
The main objective of a notary server is to perform notarization together with a prover. In this case, the prover can either be
|
||||
1. TCP client — which has access and control over the transport layer, i.e. TCP
|
||||
2. WebSocket client — which has no access over TCP and instead uses WebSocket for notarization
|
||||
|
||||
### Design Choices
|
||||
#### Web Framework
|
||||
Axum is chosen as the framework to serve HTTP and WebSocket requests from the prover clients due to its rich and well supported features, e.g. native integration with Tokio/Hyper/Tower, customizable middleware, ability to support lower level integration of TLS ([example](https://github.com/tokio-rs/axum/blob/main/examples/low-level-rustls/src/main.rs)). To simplify the notary server setup, a single Axum router is used to support both HTTP and WebSocket connections, i.e. all requests can be made to the same port of the notary server.
|
||||
|
||||
#### Notarization Configuration
|
||||
To perform notarization, some parameters need to be configured by the prover and notary server (more details in the [OpenAPI specification](./openapi.yaml)), i.e.
|
||||
- maximum transcript size
|
||||
- unique session id
|
||||
|
||||
To streamline this process, a single HTTP endpoint (`/session`) is used by both TCP and WebSocket clients.
|
||||
|
||||
#### WebSocket
|
||||
Axum's internal implementation of WebSocket uses [tokio_tungstenite](https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/), which provides a WebSocket struct that doesn't implement [AsyncRead](https://docs.rs/futures/latest/futures/io/trait.AsyncRead.html) and [AsyncWrite](https://docs.rs/futures/latest/futures/io/trait.AsyncWrite.html). Both these traits are required by TLSN core libraries for prover and notary. To overcome this, a [slight modification](./src/service/axum_websocket.rs) of Axum's implementation of WebSocket is used, where [async_tungstenite](https://docs.rs/async-tungstenite/latest/async_tungstenite/) is used instead so that [ws_stream_tungstenite](https://docs.rs/ws_stream_tungstenite/latest/ws_stream_tungstenite/index.html) can be used to wrap on top of the WebSocket struct to get AsyncRead and AsyncWrite implemented.
|
||||
|
||||
@@ -3,6 +3,9 @@ server:
|
||||
domain: "127.0.0.1"
|
||||
port: 7047
|
||||
|
||||
notarization:
|
||||
max-transcript-size: 16384
|
||||
|
||||
tls-signature:
|
||||
private-key-pem-path: "./src/fixture/tls/notary.key"
|
||||
certificate-pem-path: "./src/fixture/tls/notary.crt"
|
||||
123
openapi.yaml
Normal file
123
openapi.yaml
Normal file
@@ -0,0 +1,123 @@
|
||||
openapi: 3.0.0
|
||||
|
||||
info:
|
||||
title: Notary Server
|
||||
description: Notary server written in Rust to provide notarization service.
|
||||
version: 0.1.0
|
||||
|
||||
tags:
|
||||
- name: Notarization
|
||||
|
||||
paths:
|
||||
/session:
|
||||
post:
|
||||
tags:
|
||||
- Notarization
|
||||
description: Initialize and configure notarization for both TCP and WebSocket clients
|
||||
parameters:
|
||||
- in: header
|
||||
name: Content-Type
|
||||
description: The value must be application/json
|
||||
schema:
|
||||
type: string
|
||||
enum:
|
||||
- "application/json"
|
||||
required: true
|
||||
requestBody:
|
||||
description: Notarization session request to server
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/NotarizationSessionRequest"
|
||||
responses:
|
||||
"200":
|
||||
description: Notarization session response from server
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/NotarizationSessionResponse"
|
||||
"400":
|
||||
description: Configuration parameters or headers provided by prover are invalid
|
||||
content:
|
||||
text/plain:
|
||||
schema:
|
||||
type: string
|
||||
example: "Invalid request from prover: Failed to deserialize the JSON body into the target type"
|
||||
"500":
|
||||
description: There was some internal error when processing
|
||||
content:
|
||||
text/plain:
|
||||
schema:
|
||||
type: string
|
||||
example: "Something is wrong"
|
||||
/notarize:
|
||||
get:
|
||||
tags:
|
||||
- Notarization
|
||||
description: Start notarization for TCP client
|
||||
parameters:
|
||||
- in: header
|
||||
name: Connection
|
||||
description: The value should be 'Upgrade'
|
||||
schema:
|
||||
type: string
|
||||
enum:
|
||||
- "Upgrade"
|
||||
required: true
|
||||
- in: header
|
||||
name: Upgrade
|
||||
description: The value should be 'TCP'
|
||||
schema:
|
||||
type: string
|
||||
enum:
|
||||
- "TCP"
|
||||
required: true
|
||||
- in: header
|
||||
name: X-Session-Id
|
||||
description: Unique ID returned from server upon calling POST /session
|
||||
schema:
|
||||
type: string
|
||||
required: true
|
||||
responses:
|
||||
"101":
|
||||
description: Switching protocol response
|
||||
"400":
|
||||
description: Headers provided by prover are invalid
|
||||
content:
|
||||
text/plain:
|
||||
schema:
|
||||
type: string
|
||||
example: "Invalid request from prover: Upgrade header is not set for client"
|
||||
"500":
|
||||
description: There was some internal error when processing
|
||||
content:
|
||||
text/plain:
|
||||
schema:
|
||||
type: string
|
||||
example: "Something is wrong"
|
||||
|
||||
components:
|
||||
schemas:
|
||||
NotarizationSessionRequest:
|
||||
type: object
|
||||
properties:
|
||||
clientType:
|
||||
description: Types of client that the prover is using
|
||||
type: string
|
||||
enum:
|
||||
- "Tcp"
|
||||
- "Websocket"
|
||||
maxTranscriptSize:
|
||||
description: Maximum transcript size in bytes
|
||||
type: integer
|
||||
required:
|
||||
- "clientType"
|
||||
- "maxTranscriptSize"
|
||||
NotarizationSessionResponse:
|
||||
type: object
|
||||
properties:
|
||||
sessionId:
|
||||
type: string
|
||||
required:
|
||||
- "sessionId"
|
||||
@@ -5,6 +5,8 @@ use serde::Deserialize;
|
||||
pub struct NotaryServerProperties {
|
||||
/// Name and address of the notary server
|
||||
pub server: ServerProperties,
|
||||
/// Setting for notarization
|
||||
pub notarization: NotarizationProperties,
|
||||
/// File path of private key and certificate (in PEM format) used for establishing TLS with prover
|
||||
pub tls_signature: TLSSignatureProperties,
|
||||
/// File path of private key (in PEM format) used to sign the notarisation
|
||||
@@ -13,6 +15,13 @@ pub struct NotaryServerProperties {
|
||||
pub tracing: TracingProperties,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct NotarizationProperties {
|
||||
/// Global limit for maximum transcript size in bytes
|
||||
pub max_transcript_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct ServerProperties {
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
pub mod cli;
|
||||
pub mod notary;
|
||||
|
||||
@@ -5,6 +5,6 @@ use structopt::StructOpt;
|
||||
#[structopt(name = "Notary Server")]
|
||||
pub struct CliFields {
|
||||
/// Configuration file location
|
||||
#[structopt(long, default_value = "./src/config/config.yaml")]
|
||||
#[structopt(long, default_value = "./config.yaml")]
|
||||
pub config_file: String,
|
||||
}
|
||||
|
||||
55
src/domain/notary.rs
Normal file
55
src/domain/notary.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use p256::ecdsa::SigningKey;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::config::NotarizationProperties;
|
||||
|
||||
/// Response object of the /session API
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct NotarizationSessionResponse {
|
||||
/// Unique session id that is generated by notary and shared to prover
|
||||
pub session_id: String,
|
||||
}
|
||||
|
||||
/// Request object of the /session API
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct NotarizationSessionRequest {
|
||||
pub client_type: ClientType,
|
||||
/// Maximum transcript size in bytes
|
||||
pub max_transcript_size: Option<usize>,
|
||||
}
|
||||
|
||||
/// Types of client that the prover is using
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum ClientType {
|
||||
/// Client that has access to the transport layer
|
||||
Tcp,
|
||||
/// Client that cannot directly access transport layer, e.g. browser extension
|
||||
Websocket,
|
||||
}
|
||||
|
||||
/// Global data that needs to be shared with the axum handlers
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NotaryGlobals {
|
||||
pub notary_signing_key: SigningKey,
|
||||
pub notarization_config: NotarizationProperties,
|
||||
/// A temporary storage to store configuration data, mainly used for WebSocket client
|
||||
pub store: Arc<Mutex<HashMap<String, Option<usize>>>>,
|
||||
}
|
||||
|
||||
impl NotaryGlobals {
|
||||
pub fn new(
|
||||
notary_signing_key: SigningKey,
|
||||
notarization_config: NotarizationProperties,
|
||||
) -> Self {
|
||||
Self {
|
||||
notary_signing_key,
|
||||
notarization_config,
|
||||
store: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
22
src/error.rs
22
src/error.rs
@@ -1,3 +1,7 @@
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use eyre::Report;
|
||||
use std::error::Error;
|
||||
|
||||
@@ -11,6 +15,8 @@ pub enum NotaryServerError {
|
||||
Connection(String),
|
||||
#[error("Error occurred during notarization: {0}")]
|
||||
Notarization(Box<dyn Error + Send + 'static>),
|
||||
#[error("Invalid request from prover: {0}")]
|
||||
BadProverRequest(String),
|
||||
}
|
||||
|
||||
impl From<NotaryError> for NotaryServerError {
|
||||
@@ -24,3 +30,19 @@ impl From<NotaryConfigBuilderError> for NotaryServerError {
|
||||
Self::Notarization(Box::new(error))
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait implementation to convert this error into an axum http response
|
||||
impl IntoResponse for NotaryServerError {
|
||||
fn into_response(self) -> Response {
|
||||
match self {
|
||||
bad_request_error @ NotaryServerError::BadProverRequest(_) => {
|
||||
(StatusCode::BAD_REQUEST, bad_request_error.to_string()).into_response()
|
||||
}
|
||||
_ => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Something wrong happened.",
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
12
src/lib.rs
12
src/lib.rs
@@ -3,14 +3,18 @@ mod domain;
|
||||
mod error;
|
||||
mod server;
|
||||
mod server_tracing;
|
||||
mod service;
|
||||
mod util;
|
||||
|
||||
pub use config::{
|
||||
NotaryServerProperties, NotarySignatureProperties, ServerProperties, TLSSignatureProperties,
|
||||
TracingProperties,
|
||||
NotarizationProperties, NotaryServerProperties, NotarySignatureProperties, ServerProperties,
|
||||
TLSSignatureProperties, TracingProperties,
|
||||
};
|
||||
pub use domain::{
|
||||
cli::CliFields,
|
||||
notary::{ClientType, NotarizationSessionRequest, NotarizationSessionResponse},
|
||||
};
|
||||
pub use domain::cli::CliFields;
|
||||
pub use error::NotaryServerError;
|
||||
pub use server::{read_pem_file, run_tcp_server};
|
||||
pub use server::{read_pem_file, run_server};
|
||||
pub use server_tracing::init_tracing;
|
||||
pub use util::parse_config_file;
|
||||
|
||||
@@ -3,7 +3,7 @@ use structopt::StructOpt;
|
||||
use tracing::debug;
|
||||
|
||||
use notary_server::{
|
||||
init_tracing, parse_config_file, run_tcp_server, CliFields, NotaryServerError,
|
||||
init_tracing, parse_config_file, run_server, CliFields, NotaryServerError,
|
||||
NotaryServerProperties,
|
||||
};
|
||||
|
||||
@@ -18,8 +18,8 @@ async fn main() -> Result<(), NotaryServerError> {
|
||||
|
||||
debug!(?config, "Server config loaded");
|
||||
|
||||
// Run the tcp server
|
||||
run_tcp_server(&config).await?;
|
||||
// Run the server
|
||||
run_server(&config).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
115
src/server.rs
115
src/server.rs
@@ -1,47 +1,55 @@
|
||||
use axum::{
|
||||
http::{Request, StatusCode},
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use eyre::{ensure, eyre, Result};
|
||||
use futures_util::future::poll_fn;
|
||||
use p256::{
|
||||
ecdsa::{Signature, SigningKey},
|
||||
pkcs8::DecodePrivateKey,
|
||||
use hyper::server::{
|
||||
accept::Accept,
|
||||
conn::{AddrIncoming, Http},
|
||||
};
|
||||
use p256::{ecdsa::SigningKey, pkcs8::DecodePrivateKey};
|
||||
use rustls::{Certificate, PrivateKey, ServerConfig};
|
||||
use std::{
|
||||
fs::File as StdFile,
|
||||
io::BufReader,
|
||||
net::{IpAddr, SocketAddr},
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
};
|
||||
use tlsn_notary::{bind_notary, NotaryConfig};
|
||||
use tokio::{
|
||||
fs::File,
|
||||
io::{AsyncRead, AsyncWrite},
|
||||
net::TcpListener,
|
||||
};
|
||||
|
||||
use tokio::{fs::File, net::TcpListener};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_util::compat::TokioAsyncReadCompatExt;
|
||||
use tower::MakeService;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use crate::{
|
||||
config::{NotaryServerProperties, NotarySignatureProperties, TLSSignatureProperties},
|
||||
domain::notary::NotaryGlobals,
|
||||
error::NotaryServerError,
|
||||
service::{initialize, upgrade_protocol},
|
||||
};
|
||||
|
||||
/// Start a TLS-secured TCP server to accept notarization request
|
||||
/// Start a TLS-secured TCP server to accept notarization request for both TCP and WebSocket clients
|
||||
#[tracing::instrument(skip(config))]
|
||||
pub async fn run_tcp_server(config: &NotaryServerProperties) -> Result<(), NotaryServerError> {
|
||||
pub async fn run_server(config: &NotaryServerProperties) -> Result<(), NotaryServerError> {
|
||||
// Load the private key and cert needed for TLS connection from fixture folder — can be swapped out when we stop using static self signed cert
|
||||
let (tls_private_key, tls_certificates) = load_tls_key_and_cert(&config.tls_signature).await?;
|
||||
// Load the private key for notarized transcript signing from fixture folder — can be swapped out when we use proper ephemeral signing key
|
||||
let notary_signing_key = load_notary_signing_key(&config.notary_signature).await?;
|
||||
|
||||
// Build a TCP listener with TLS enabled
|
||||
let tls_config = Arc::new(
|
||||
ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(tls_certificates, tls_private_key)
|
||||
.map_err(|err| eyre!("Failed to instantiate notary server tls config: {err}"))?,
|
||||
);
|
||||
let mut server_config = ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(tls_certificates, tls_private_key)
|
||||
.map_err(|err| eyre!("Failed to instantiate notary server tls config: {err}"))?;
|
||||
|
||||
// Set the http protocols we support
|
||||
server_config.alpn_protocols = vec![b"http/1.1".to_vec()];
|
||||
let tls_config = Arc::new(server_config);
|
||||
|
||||
let notary_address = SocketAddr::new(
|
||||
IpAddr::V4(config.server.domain.parse().map_err(|err| {
|
||||
@@ -54,25 +62,42 @@ pub async fn run_tcp_server(config: &NotaryServerProperties) -> Result<(), Notar
|
||||
let listener = TcpListener::bind(notary_address)
|
||||
.await
|
||||
.map_err(|err| eyre!("Failed to bind server address to tcp listener: {err}"))?;
|
||||
let mut listener = AddrIncoming::from_listener(listener)
|
||||
.map_err(|err| eyre!("Failed to build hyper tcp listener: {err}"))?;
|
||||
|
||||
info!(
|
||||
"Listening for TLS-secured TCP traffic at {}",
|
||||
notary_address
|
||||
);
|
||||
|
||||
let protocol = Arc::new(Http::new());
|
||||
let notary_globals = NotaryGlobals::new(notary_signing_key, config.notarization.clone());
|
||||
let router = Router::new()
|
||||
.route(
|
||||
"/healthcheck",
|
||||
get(|| async move { (StatusCode::OK, "Ok").into_response() }),
|
||||
)
|
||||
.route("/session", post(initialize))
|
||||
.route("/notarize", get(upgrade_protocol))
|
||||
.with_state(notary_globals);
|
||||
let mut app = router.into_make_service();
|
||||
|
||||
loop {
|
||||
// Poll for any incoming connection constantly
|
||||
let (stream, prover_address) = match poll_fn(|cx| listener.poll_accept(cx)).await {
|
||||
Ok(connection) => connection,
|
||||
Err(err) => {
|
||||
error!("{}", NotaryServerError::Connection(err.to_string()));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
// Poll and await for any incoming connection, ensure that all operations inside are infallible to prevent bringing down the server
|
||||
let (prover_address, stream) =
|
||||
match poll_fn(|cx| Pin::new(&mut listener).poll_accept(cx)).await {
|
||||
Some(Ok(connection)) => (connection.remote_addr(), connection),
|
||||
Some(Err(err)) => {
|
||||
error!("{}", NotaryServerError::Connection(err.to_string()));
|
||||
continue;
|
||||
}
|
||||
None => unreachable!("The poll_accept method should never return None"),
|
||||
};
|
||||
debug!(?prover_address, "Received a prover's TCP connection");
|
||||
|
||||
let acceptor = acceptor.clone();
|
||||
let notary_signing_key = notary_signing_key.clone();
|
||||
let protocol = protocol.clone();
|
||||
let service = MakeService::<_, Request<hyper::Body>>::make_service(&mut app, &stream);
|
||||
|
||||
// Spawn a new async task to handle the new connection
|
||||
tokio::spawn(async move {
|
||||
@@ -82,16 +107,14 @@ pub async fn run_tcp_server(config: &NotaryServerProperties) -> Result<(), Notar
|
||||
?prover_address,
|
||||
"Accepted prover's TLS-secured TCP connection",
|
||||
);
|
||||
match notary_service(stream, &prover_address.to_string(), ¬ary_signing_key)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {
|
||||
info!(?prover_address, "Successful notarization!");
|
||||
}
|
||||
Err(err) => {
|
||||
error!(?prover_address, "Failed notarization: {err}");
|
||||
}
|
||||
}
|
||||
// Serve different requests using the same hyper protocol and axum router
|
||||
let _ = protocol
|
||||
// Can unwrap because it's infallible
|
||||
.serve_connection(stream, service.await.unwrap())
|
||||
// use with_upgrades to upgrade connection to websocket for websocket clients
|
||||
// and to extract tcp connection for tcp clients
|
||||
.with_upgrades()
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
error!(
|
||||
@@ -105,22 +128,6 @@ pub async fn run_tcp_server(config: &NotaryServerProperties) -> Result<(), Notar
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the notarization
|
||||
async fn notary_service<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
|
||||
socket: T,
|
||||
prover_address: &str,
|
||||
signing_key: &SigningKey,
|
||||
) -> Result<(), NotaryServerError> {
|
||||
debug!(?prover_address, "Starting notarization...");
|
||||
|
||||
// Temporarily use the prover address as the notarization session id as it is unique for each prover
|
||||
let config = NotaryConfig::builder().id(prover_address).build()?;
|
||||
let (notary, notary_fut) = bind_notary(config, socket.compat())?;
|
||||
|
||||
// Run the notary and background processes concurrently
|
||||
tokio::try_join!(notary_fut, notary.notarize::<Signature>(signing_key),).map(|_| Ok(()))?
|
||||
}
|
||||
|
||||
/// Temporary function to load notary signing key from static file
|
||||
async fn load_notary_signing_key(config: &NotarySignatureProperties) -> Result<SigningKey> {
|
||||
debug!("Loading notary server's signing key");
|
||||
|
||||
185
src/service.rs
Normal file
185
src/service.rs
Normal file
@@ -0,0 +1,185 @@
|
||||
pub mod axum_websocket;
|
||||
pub mod tcp;
|
||||
pub mod websocket;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
extract::{rejection::JsonRejection, FromRequestParts, State},
|
||||
http::{header, request::Parts, HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Json, Response},
|
||||
};
|
||||
use axum_macros::debug_handler;
|
||||
use p256::ecdsa::{Signature, SigningKey};
|
||||
use tlsn_notary::{bind_notary, NotaryConfig};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::compat::TokioAsyncReadCompatExt;
|
||||
use tracing::{debug, error, info, trace};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
domain::notary::{NotarizationSessionRequest, NotarizationSessionResponse, NotaryGlobals},
|
||||
error::NotaryServerError,
|
||||
service::{
|
||||
axum_websocket::{header_eq, WebSocketUpgrade},
|
||||
tcp::{tcp_notarize, TcpUpgrade},
|
||||
websocket::websocket_notarize,
|
||||
},
|
||||
};
|
||||
|
||||
/// A wrapper enum to facilitate extracting TCP connection for either WebSocket or TCP clients,
|
||||
/// so that we can use a single endpoint and handler for notarization for both types of clients
|
||||
pub enum ProtocolUpgrade {
|
||||
Tcp(TcpUpgrade),
|
||||
Ws(WebSocketUpgrade),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for ProtocolUpgrade
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = NotaryServerError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
// Extract tcp connection for websocket client
|
||||
if header_eq(&parts.headers, header::UPGRADE, "websocket") {
|
||||
let extractor = WebSocketUpgrade::from_request_parts(parts, state)
|
||||
.await
|
||||
.map_err(|err| NotaryServerError::BadProverRequest(err.to_string()))?;
|
||||
return Ok(Self::Ws(extractor));
|
||||
// Extract tcp connection for tcp client
|
||||
} else if header_eq(&parts.headers, header::UPGRADE, "tcp") {
|
||||
let extractor = TcpUpgrade::from_request_parts(parts, state)
|
||||
.await
|
||||
.map_err(|err| NotaryServerError::BadProverRequest(err.to_string()))?;
|
||||
return Ok(Self::Tcp(extractor));
|
||||
} else {
|
||||
return Err(NotaryServerError::BadProverRequest(
|
||||
"Upgrade header is not set for client".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handler to upgrade protocol from http to either websocket or underlying tcp depending on the type of client
|
||||
/// the session_id header is also extracted here to fetch the configuration parameters
|
||||
/// that have been submitted in the previous request to /session made by the same client
|
||||
pub async fn upgrade_protocol(
|
||||
protocol_upgrade: ProtocolUpgrade,
|
||||
mut headers: HeaderMap,
|
||||
State(notary_globals): State<NotaryGlobals>,
|
||||
) -> Response {
|
||||
info!("Received upgrade protocol request");
|
||||
// Extract the session_id from the headers
|
||||
let session_id = match headers.remove("X-Session-Id") {
|
||||
Some(session_id) => match session_id.to_str() {
|
||||
Ok(session_id) => session_id.to_string(),
|
||||
Err(err) => {
|
||||
let err_msg = format!("X-Session-Id header submitted is not a string: {}", err);
|
||||
error!(err_msg);
|
||||
return NotaryServerError::BadProverRequest(err_msg).into_response();
|
||||
}
|
||||
},
|
||||
None => {
|
||||
let err_msg = "Missing X-Session-Id in upgrade protocol request".to_string();
|
||||
error!(err_msg);
|
||||
return NotaryServerError::BadProverRequest(err_msg).into_response();
|
||||
}
|
||||
};
|
||||
// Fetch the configuration data from the store using the session_id
|
||||
let max_transcript_size = match notary_globals.store.lock().await.get(&session_id) {
|
||||
Some(max_transcript_size) => max_transcript_size.to_owned(),
|
||||
None => {
|
||||
let err_msg = format!("Session id {} does not exist", session_id);
|
||||
error!(err_msg);
|
||||
return NotaryServerError::BadProverRequest(err_msg).into_response();
|
||||
}
|
||||
};
|
||||
// This completes the HTTP Upgrade request and returns a successful response to the client, meanwhile initiating the websocket or tcp connection
|
||||
match protocol_upgrade {
|
||||
ProtocolUpgrade::Ws(ws) => ws.on_upgrade(move |socket| {
|
||||
websocket_notarize(socket, notary_globals, session_id, max_transcript_size)
|
||||
}),
|
||||
ProtocolUpgrade::Tcp(tcp) => tcp.on_upgrade(move |stream| {
|
||||
tcp_notarize(stream, notary_globals, session_id, max_transcript_size)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handler to initialize and configure notarization for both TCP and WebSocket clients
|
||||
#[debug_handler(state = NotaryGlobals)]
|
||||
pub async fn initialize(
|
||||
State(notary_globals): State<NotaryGlobals>,
|
||||
payload: Result<Json<NotarizationSessionRequest>, JsonRejection>,
|
||||
) -> impl IntoResponse {
|
||||
info!(
|
||||
?payload,
|
||||
"Received request for initializing a notarization session"
|
||||
);
|
||||
|
||||
// Parse the body payload
|
||||
let payload = match payload {
|
||||
Ok(payload) => payload,
|
||||
Err(err) => {
|
||||
error!("Malformed payload submitted for initializing notarization: {err}");
|
||||
return NotaryServerError::BadProverRequest(err.to_string()).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Ensure that the max_transcript_size submitted is not larger than the global max limit configured in notary server
|
||||
if payload.max_transcript_size > Some(notary_globals.notarization_config.max_transcript_size) {
|
||||
error!(
|
||||
"Max transcript size requested {:?} exceeds the maximum threshold {:?}",
|
||||
payload.max_transcript_size, notary_globals.notarization_config.max_transcript_size
|
||||
);
|
||||
return NotaryServerError::BadProverRequest(
|
||||
"Max transcript size requested exceeds the maximum threshold".to_string(),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let prover_session_id = Uuid::new_v4().to_string();
|
||||
|
||||
// Store the configuration data in a temporary store
|
||||
notary_globals
|
||||
.store
|
||||
.lock()
|
||||
.await
|
||||
.insert(prover_session_id.clone(), payload.max_transcript_size);
|
||||
|
||||
trace!("Latest store state: {:?}", notary_globals.store);
|
||||
|
||||
// Return the session id in the response to the client
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(NotarizationSessionResponse {
|
||||
session_id: prover_session_id,
|
||||
}),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Run the notarization
|
||||
pub async fn notary_service<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
socket: T,
|
||||
signing_key: &SigningKey,
|
||||
session_id: &str,
|
||||
max_transcript_size: Option<usize>,
|
||||
) -> Result<(), NotaryServerError> {
|
||||
debug!(?session_id, "Starting notarization...");
|
||||
|
||||
let mut config_builder = NotaryConfig::builder();
|
||||
|
||||
config_builder.id(session_id);
|
||||
|
||||
if let Some(max_transcript_size) = max_transcript_size {
|
||||
config_builder.max_transcript_size(max_transcript_size);
|
||||
}
|
||||
|
||||
let config = config_builder.build()?;
|
||||
|
||||
let (notary, notary_fut) = bind_notary(config, socket.compat())?;
|
||||
|
||||
// Run the notary and background processes concurrently
|
||||
tokio::try_join!(notary_fut, notary.notarize::<Signature>(signing_key),).map(|_| Ok(()))?
|
||||
}
|
||||
913
src/service/axum_websocket.rs
Normal file
913
src/service/axum_websocket.rs
Normal file
@@ -0,0 +1,913 @@
|
||||
//! The following code is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.6.19/axum/src/extract/ws.rs
|
||||
//! where we swapped out tokio_tungstenite (https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/)
|
||||
//! with async_tungstenite (https://docs.rs/async-tungstenite/latest/async_tungstenite/) so that we can use
|
||||
//! ws_stream_tungstenite (https://docs.rs/ws_stream_tungstenite/latest/ws_stream_tungstenite/index.html)
|
||||
//! to get AsyncRead and AsyncWrite implemented for the WebSocket. Any other modification is commented with the prefix "NOTARY_MODIFICATION:"
|
||||
//!
|
||||
//! The code is under the following license:
|
||||
//!
|
||||
//! Copyright (c) 2019 Axum Contributors
|
||||
//!
|
||||
//! Permission is hereby granted, free of charge, to any
|
||||
//! person obtaining a copy of this software and associated
|
||||
//! documentation files (the "Software"), to deal in the
|
||||
//! Software without restriction, including without
|
||||
//! limitation the rights to use, copy, modify, merge,
|
||||
//! publish, distribute, sublicense, and/or sell copies of
|
||||
//! the Software, and to permit persons to whom the Software
|
||||
//! is furnished to do so, subject to the following
|
||||
//! conditions:
|
||||
//!
|
||||
//! The above copyright notice and this permission notice
|
||||
//! shall be included in all copies or substantial portions
|
||||
//! of the Software.
|
||||
//!
|
||||
//! THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
|
||||
//! ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
|
||||
//! TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
|
||||
//! PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
|
||||
//! SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
//! CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
//! OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
|
||||
//! IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
//! DEALINGS IN THE SOFTWARE.
|
||||
//!
|
||||
//!
|
||||
//! Handle WebSocket connections.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```
|
||||
//! use axum::{
|
||||
//! extract::ws::{WebSocketUpgrade, WebSocket},
|
||||
//! routing::get,
|
||||
//! response::{IntoResponse, Response},
|
||||
//! Router,
|
||||
//! };
|
||||
//!
|
||||
//! let app = Router::new().route("/ws", get(handler));
|
||||
//!
|
||||
//! async fn handler(ws: WebSocketUpgrade) -> Response {
|
||||
//! ws.on_upgrade(handle_socket)
|
||||
//! }
|
||||
//!
|
||||
//! async fn handle_socket(mut socket: WebSocket) {
|
||||
//! while let Some(msg) = socket.recv().await {
|
||||
//! let msg = if let Ok(msg) = msg {
|
||||
//! msg
|
||||
//! } else {
|
||||
//! // client disconnected
|
||||
//! return;
|
||||
//! };
|
||||
//!
|
||||
//! if socket.send(msg).await.is_err() {
|
||||
//! // client disconnected
|
||||
//! return;
|
||||
//! }
|
||||
//! }
|
||||
//! }
|
||||
//! # async {
|
||||
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
//! # };
|
||||
//! ```
|
||||
//!
|
||||
//! # Passing data and/or state to an `on_upgrade` callback
|
||||
//!
|
||||
//! ```
|
||||
//! use axum::{
|
||||
//! extract::{ws::{WebSocketUpgrade, WebSocket}, State},
|
||||
//! response::Response,
|
||||
//! routing::get,
|
||||
//! Router,
|
||||
//! };
|
||||
//!
|
||||
//! #[derive(Clone)]
|
||||
//! struct AppState {
|
||||
//! // ...
|
||||
//! }
|
||||
//!
|
||||
//! async fn handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
|
||||
//! ws.on_upgrade(|socket| handle_socket(socket, state))
|
||||
//! }
|
||||
//!
|
||||
//! async fn handle_socket(socket: WebSocket, state: AppState) {
|
||||
//! // ...
|
||||
//! }
|
||||
//!
|
||||
//! let app = Router::new()
|
||||
//! .route("/ws", get(handler))
|
||||
//! .with_state(AppState { /* ... */ });
|
||||
//! # async {
|
||||
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
//! # };
|
||||
//! ```
|
||||
//!
|
||||
//! # Read and write concurrently
|
||||
//!
|
||||
//! If you need to read and write concurrently from a [`WebSocket`] you can use
|
||||
//! [`StreamExt::split`]:
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use axum::{Error, extract::ws::{WebSocket, Message}};
|
||||
//! use futures_util::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}};
|
||||
//!
|
||||
//! async fn handle_socket(mut socket: WebSocket) {
|
||||
//! let (mut sender, mut receiver) = socket.split();
|
||||
//!
|
||||
//! tokio::spawn(write(sender));
|
||||
//! tokio::spawn(read(receiver));
|
||||
//! }
|
||||
//!
|
||||
//! async fn read(receiver: SplitStream<WebSocket>) {
|
||||
//! // ...
|
||||
//! }
|
||||
//!
|
||||
//! async fn write(sender: SplitSink<WebSocket, Message>) {
|
||||
//! // ...
|
||||
//! }
|
||||
//! ```
|
||||
//!
|
||||
//! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
|
||||
|
||||
#![allow(unused)]
|
||||
|
||||
use self::rejection::*;
|
||||
use async_trait::async_trait;
|
||||
use async_tungstenite::{
|
||||
tokio::TokioAdapter,
|
||||
tungstenite::{
|
||||
self as ts,
|
||||
protocol::{self, WebSocketConfig},
|
||||
},
|
||||
WebSocketStream,
|
||||
};
|
||||
use axum::{
|
||||
body::{self, Bytes},
|
||||
extract::FromRequestParts,
|
||||
response::Response,
|
||||
Error,
|
||||
};
|
||||
|
||||
use futures_util::{
|
||||
sink::{Sink, SinkExt},
|
||||
stream::{Stream, StreamExt},
|
||||
};
|
||||
use http::{
|
||||
header::{self, HeaderMap, HeaderName, HeaderValue},
|
||||
request::Parts,
|
||||
Method, StatusCode,
|
||||
};
|
||||
use hyper::upgrade::{OnUpgrade, Upgraded};
|
||||
use sha1::{Digest, Sha1};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tracing::error;
|
||||
|
||||
/// Extractor for establishing WebSocket connections.
|
||||
///
|
||||
/// Note: This extractor requires the request method to be `GET` so it should
|
||||
/// always be used with [`get`](crate::routing::get). Requests with other methods will be
|
||||
/// rejected.
|
||||
///
|
||||
/// See the [module docs](self) for an example.
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
|
||||
pub struct WebSocketUpgrade<F = DefaultOnFailedUpdgrade> {
|
||||
config: WebSocketConfig,
|
||||
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
|
||||
protocol: Option<HeaderValue>,
|
||||
sec_websocket_key: HeaderValue,
|
||||
on_upgrade: OnUpgrade,
|
||||
on_failed_upgrade: F,
|
||||
sec_websocket_protocol: Option<HeaderValue>,
|
||||
}
|
||||
|
||||
impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("WebSocketUpgrade")
|
||||
.field("config", &self.config)
|
||||
.field("protocol", &self.protocol)
|
||||
.field("sec_websocket_key", &self.sec_websocket_key)
|
||||
.field("sec_websocket_protocol", &self.sec_websocket_protocol)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> WebSocketUpgrade<F> {
|
||||
/// Set the size of the internal message send queue.
|
||||
pub fn max_send_queue(mut self, max: usize) -> Self {
|
||||
self.config.max_send_queue = Some(max);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the maximum message size (defaults to 64 megabytes)
|
||||
pub fn max_message_size(mut self, max: usize) -> Self {
|
||||
self.config.max_message_size = Some(max);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the maximum frame size (defaults to 16 megabytes)
|
||||
pub fn max_frame_size(mut self, max: usize) -> Self {
|
||||
self.config.max_frame_size = Some(max);
|
||||
self
|
||||
}
|
||||
|
||||
/// Allow server to accept unmasked frames (defaults to false)
|
||||
pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
|
||||
self.config.accept_unmasked_frames = accept;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the known protocols.
|
||||
///
|
||||
/// If the protocol name specified by `Sec-WebSocket-Protocol` header
|
||||
/// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
|
||||
/// return the protocol name.
|
||||
///
|
||||
/// The protocols should be listed in decreasing order of preference: if the client offers
|
||||
/// multiple protocols that the server could support, the server will pick the first one in
|
||||
/// this list.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use axum::{
|
||||
/// extract::ws::{WebSocketUpgrade, WebSocket},
|
||||
/// routing::get,
|
||||
/// response::{IntoResponse, Response},
|
||||
/// Router,
|
||||
/// };
|
||||
///
|
||||
/// let app = Router::new().route("/ws", get(handler));
|
||||
///
|
||||
/// async fn handler(ws: WebSocketUpgrade) -> Response {
|
||||
/// ws.protocols(["graphql-ws", "graphql-transport-ws"])
|
||||
/// .on_upgrade(|socket| async {
|
||||
/// // ...
|
||||
/// })
|
||||
/// }
|
||||
/// # async {
|
||||
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
pub fn protocols<I>(mut self, protocols: I) -> Self
|
||||
where
|
||||
I: IntoIterator,
|
||||
I::Item: Into<Cow<'static, str>>,
|
||||
{
|
||||
if let Some(req_protocols) = self
|
||||
.sec_websocket_protocol
|
||||
.as_ref()
|
||||
.and_then(|p| p.to_str().ok())
|
||||
{
|
||||
self.protocol = protocols
|
||||
.into_iter()
|
||||
// FIXME: This will often allocate a new `String` and so is less efficient than it
|
||||
// could be. But that can't be fixed without breaking changes to the public API.
|
||||
.map(Into::into)
|
||||
.find(|protocol| {
|
||||
req_protocols
|
||||
.split(',')
|
||||
.any(|req_protocol| req_protocol.trim() == protocol)
|
||||
})
|
||||
.map(|protocol| match protocol {
|
||||
Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
|
||||
Cow::Borrowed(s) => HeaderValue::from_static(s),
|
||||
});
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Provide a callback to call if upgrading the connection fails.
|
||||
///
|
||||
/// The connection upgrade is performed in a background task. If that fails this callback
|
||||
/// will be called.
|
||||
///
|
||||
/// By default any errors will be silently ignored.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use axum::{
|
||||
/// extract::{WebSocketUpgrade},
|
||||
/// response::Response,
|
||||
/// };
|
||||
///
|
||||
/// async fn handler(ws: WebSocketUpgrade) -> Response {
|
||||
/// ws.on_failed_upgrade(|error| {
|
||||
/// report_error(error);
|
||||
/// })
|
||||
/// .on_upgrade(|socket| async { /* ... */ })
|
||||
/// }
|
||||
/// #
|
||||
/// # fn report_error(_: axum::Error) {}
|
||||
/// ```
|
||||
pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
|
||||
where
|
||||
C: OnFailedUpdgrade,
|
||||
{
|
||||
WebSocketUpgrade {
|
||||
config: self.config,
|
||||
protocol: self.protocol,
|
||||
sec_websocket_key: self.sec_websocket_key,
|
||||
on_upgrade: self.on_upgrade,
|
||||
on_failed_upgrade: callback,
|
||||
sec_websocket_protocol: self.sec_websocket_protocol,
|
||||
}
|
||||
}
|
||||
|
||||
/// Finalize upgrading the connection and call the provided callback with
|
||||
/// the stream.
|
||||
#[must_use = "to setup the WebSocket connection, this response must be returned"]
|
||||
pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
|
||||
where
|
||||
C: FnOnce(WebSocket) -> Fut + Send + 'static,
|
||||
Fut: Future<Output = ()> + Send + 'static,
|
||||
F: OnFailedUpdgrade,
|
||||
{
|
||||
let on_upgrade = self.on_upgrade;
|
||||
let config = self.config;
|
||||
let on_failed_upgrade = self.on_failed_upgrade;
|
||||
|
||||
let protocol = self.protocol.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let upgraded = match on_upgrade.await {
|
||||
Ok(upgraded) => upgraded,
|
||||
Err(err) => {
|
||||
error!("Something wrong with on_upgrade: {:?}", err);
|
||||
on_failed_upgrade.call(Error::new(err));
|
||||
return;
|
||||
}
|
||||
};
|
||||
let socket = WebSocketStream::from_raw_socket(
|
||||
// NOTARY_MODIFICATION: Need to use TokioAdapter to wrap Upgraded which doesn't implement futures crate's AsyncRead and AsyncWrite
|
||||
TokioAdapter::new(upgraded),
|
||||
protocol::Role::Server,
|
||||
Some(config),
|
||||
)
|
||||
.await;
|
||||
let socket = WebSocket {
|
||||
inner: socket,
|
||||
protocol,
|
||||
};
|
||||
callback(socket).await;
|
||||
});
|
||||
|
||||
#[allow(clippy::declare_interior_mutable_const)]
|
||||
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
|
||||
#[allow(clippy::declare_interior_mutable_const)]
|
||||
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
|
||||
|
||||
let mut builder = Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(header::CONNECTION, UPGRADE)
|
||||
.header(header::UPGRADE, WEBSOCKET)
|
||||
.header(
|
||||
header::SEC_WEBSOCKET_ACCEPT,
|
||||
sign(self.sec_websocket_key.as_bytes()),
|
||||
);
|
||||
|
||||
if let Some(protocol) = self.protocol {
|
||||
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
|
||||
}
|
||||
|
||||
builder.body(body::boxed(body::Empty::new())).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// What to do when a connection upgrade fails.
|
||||
///
|
||||
/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
|
||||
pub trait OnFailedUpdgrade: Send + 'static {
|
||||
/// Call the callback.
|
||||
fn call(self, error: Error);
|
||||
}
|
||||
|
||||
impl<F> OnFailedUpdgrade for F
|
||||
where
|
||||
F: FnOnce(Error) + Send + 'static,
|
||||
{
|
||||
fn call(self, error: Error) {
|
||||
self(error)
|
||||
}
|
||||
}
|
||||
|
||||
/// The default `OnFailedUpdgrade` used by `WebSocketUpgrade`.
|
||||
///
|
||||
/// It simply ignores the error.
|
||||
#[non_exhaustive]
|
||||
#[derive(Debug)]
|
||||
pub struct DefaultOnFailedUpdgrade;
|
||||
|
||||
impl OnFailedUpdgrade for DefaultOnFailedUpdgrade {
|
||||
#[inline]
|
||||
fn call(self, _error: Error) {}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpdgrade>
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = WebSocketUpgradeRejection;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
if parts.method != Method::GET {
|
||||
return Err(MethodNotGet.into());
|
||||
}
|
||||
|
||||
if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
|
||||
return Err(InvalidConnectionHeader.into());
|
||||
}
|
||||
|
||||
if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
|
||||
return Err(InvalidUpgradeHeader.into());
|
||||
}
|
||||
|
||||
if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
|
||||
return Err(InvalidWebSocketVersionHeader.into());
|
||||
}
|
||||
|
||||
let sec_websocket_key = parts
|
||||
.headers
|
||||
.get(header::SEC_WEBSOCKET_KEY)
|
||||
.ok_or(WebSocketKeyHeaderMissing)?
|
||||
.clone();
|
||||
|
||||
let on_upgrade = parts
|
||||
.extensions
|
||||
.remove::<OnUpgrade>()
|
||||
.ok_or(ConnectionNotUpgradable)?;
|
||||
|
||||
let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
|
||||
|
||||
Ok(Self {
|
||||
config: Default::default(),
|
||||
protocol: None,
|
||||
sec_websocket_key,
|
||||
on_upgrade,
|
||||
sec_websocket_protocol,
|
||||
on_failed_upgrade: DefaultOnFailedUpdgrade,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
|
||||
if let Some(header) = headers.get(&key) {
|
||||
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
|
||||
let header = if let Some(header) = headers.get(&key) {
|
||||
header
|
||||
} else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
|
||||
header.to_ascii_lowercase().contains(value)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// A stream of WebSocket messages.
|
||||
///
|
||||
/// See [the module level documentation](self) for more details.
|
||||
#[derive(Debug)]
|
||||
pub struct WebSocket {
|
||||
inner: WebSocketStream<TokioAdapter<Upgraded>>,
|
||||
protocol: Option<HeaderValue>,
|
||||
}
|
||||
|
||||
impl WebSocket {
|
||||
/// Consume `self` and get the inner [`async_tungstenite::WebSocketStream`].
|
||||
pub fn into_inner(self) -> WebSocketStream<TokioAdapter<Upgraded>> {
|
||||
self.inner
|
||||
}
|
||||
|
||||
/// Receive another message.
|
||||
///
|
||||
/// Returns `None` if the stream has closed.
|
||||
pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
|
||||
self.next().await
|
||||
}
|
||||
|
||||
/// Send a message.
|
||||
pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
|
||||
self.inner
|
||||
.send(msg.into_tungstenite())
|
||||
.await
|
||||
.map_err(Error::new)
|
||||
}
|
||||
|
||||
/// Gracefully close this WebSocket.
|
||||
pub async fn close(mut self) -> Result<(), Error> {
|
||||
self.inner.close(None).await.map_err(Error::new)
|
||||
}
|
||||
|
||||
/// Return the selected WebSocket subprotocol, if one has been chosen.
|
||||
pub fn protocol(&self) -> Option<&HeaderValue> {
|
||||
self.protocol.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for WebSocket {
|
||||
type Item = Result<Message, Error>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
|
||||
Some(Ok(msg)) => {
|
||||
if let Some(msg) = Message::from_tungstenite(msg) {
|
||||
return Poll::Ready(Some(Ok(msg)));
|
||||
}
|
||||
}
|
||||
Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))),
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<Message> for WebSocket {
|
||||
type Error = Error;
|
||||
|
||||
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new)
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
|
||||
Pin::new(&mut self.inner)
|
||||
.start_send(item.into_tungstenite())
|
||||
.map_err(Error::new)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new)
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new)
|
||||
}
|
||||
}
|
||||
|
||||
/// Status code used to indicate why an endpoint is closing the WebSocket connection.
|
||||
pub type CloseCode = u16;
|
||||
|
||||
/// A struct representing the close command.
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub struct CloseFrame<'t> {
|
||||
/// The reason as a code.
|
||||
pub code: CloseCode,
|
||||
/// The reason as text string.
|
||||
pub reason: Cow<'t, str>,
|
||||
}
|
||||
|
||||
/// A WebSocket message.
|
||||
//
|
||||
// This code comes from https://github.com/snapview/tungstenite-rs/blob/master/src/protocol/message.rs and is under following license:
|
||||
// Copyright (c) 2017 Alexey Galakhov
|
||||
// Copyright (c) 2016 Jason Housley
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
// of this software and associated documentation files (the "Software"), to deal
|
||||
// in the Software without restriction, including without limitation the rights
|
||||
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
// copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
// THE SOFTWARE.
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
pub enum Message {
|
||||
/// A text WebSocket message
|
||||
Text(String),
|
||||
/// A binary WebSocket message
|
||||
Binary(Vec<u8>),
|
||||
/// A ping message with the specified payload
|
||||
///
|
||||
/// The payload here must have a length less than 125 bytes.
|
||||
///
|
||||
/// Ping messages will be automatically responded to by the server, so you do not have to worry
|
||||
/// about dealing with them yourself.
|
||||
Ping(Vec<u8>),
|
||||
/// A pong message with the specified payload
|
||||
///
|
||||
/// The payload here must have a length less than 125 bytes.
|
||||
///
|
||||
/// Pong messages will be automatically sent to the client if a ping message is received, so
|
||||
/// you do not have to worry about constructing them yourself unless you want to implement a
|
||||
/// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
|
||||
Pong(Vec<u8>),
|
||||
/// A close message with the optional close frame.
|
||||
Close(Option<CloseFrame<'static>>),
|
||||
}
|
||||
|
||||
impl Message {
|
||||
fn into_tungstenite(self) -> ts::Message {
|
||||
match self {
|
||||
Self::Text(text) => ts::Message::Text(text),
|
||||
Self::Binary(binary) => ts::Message::Binary(binary),
|
||||
Self::Ping(ping) => ts::Message::Ping(ping),
|
||||
Self::Pong(pong) => ts::Message::Pong(pong),
|
||||
Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
|
||||
code: ts::protocol::frame::coding::CloseCode::from(close.code),
|
||||
reason: close.reason,
|
||||
})),
|
||||
Self::Close(None) => ts::Message::Close(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn from_tungstenite(message: ts::Message) -> Option<Self> {
|
||||
match message {
|
||||
ts::Message::Text(text) => Some(Self::Text(text)),
|
||||
ts::Message::Binary(binary) => Some(Self::Binary(binary)),
|
||||
ts::Message::Ping(ping) => Some(Self::Ping(ping)),
|
||||
ts::Message::Pong(pong) => Some(Self::Pong(pong)),
|
||||
ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
|
||||
code: close.code.into(),
|
||||
reason: close.reason,
|
||||
}))),
|
||||
ts::Message::Close(None) => Some(Self::Close(None)),
|
||||
// we can ignore `Frame` frames as recommended by the tungstenite maintainers
|
||||
// https://github.com/snapview/tungstenite-rs/issues/268
|
||||
ts::Message::Frame(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Consume the WebSocket and return it as binary data.
|
||||
pub fn into_data(self) -> Vec<u8> {
|
||||
match self {
|
||||
Self::Text(string) => string.into_bytes(),
|
||||
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
|
||||
Self::Close(None) => Vec::new(),
|
||||
Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to consume the WebSocket message and convert it to a String.
|
||||
pub fn into_text(self) -> Result<String, Error> {
|
||||
match self {
|
||||
Self::Text(string) => Ok(string),
|
||||
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data)
|
||||
.map_err(|err| err.utf8_error())
|
||||
.map_err(Error::new)?),
|
||||
Self::Close(None) => Ok(String::new()),
|
||||
Self::Close(Some(frame)) => Ok(frame.reason.into_owned()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to get a &str from the WebSocket message,
|
||||
/// this will try to convert binary data to utf8.
|
||||
pub fn to_text(&self) -> Result<&str, Error> {
|
||||
match *self {
|
||||
Self::Text(ref string) => Ok(string),
|
||||
Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
|
||||
Ok(std::str::from_utf8(data).map_err(Error::new)?)
|
||||
}
|
||||
Self::Close(None) => Ok(""),
|
||||
Self::Close(Some(ref frame)) => Ok(&frame.reason),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for Message {
|
||||
fn from(string: String) -> Self {
|
||||
Message::Text(string)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'s> From<&'s str> for Message {
|
||||
fn from(string: &'s str) -> Self {
|
||||
Message::Text(string.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'b> From<&'b [u8]> for Message {
|
||||
fn from(data: &'b [u8]) -> Self {
|
||||
Message::Binary(data.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<u8>> for Message {
|
||||
fn from(data: Vec<u8>) -> Self {
|
||||
Message::Binary(data)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Message> for Vec<u8> {
|
||||
fn from(msg: Message) -> Self {
|
||||
msg.into_data()
|
||||
}
|
||||
}
|
||||
|
||||
fn sign(key: &[u8]) -> HeaderValue {
|
||||
use base64::engine::Engine as _;
|
||||
|
||||
let mut sha1 = Sha1::default();
|
||||
sha1.update(key);
|
||||
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
|
||||
let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
|
||||
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
|
||||
}
|
||||
|
||||
pub mod rejection {
|
||||
//! WebSocket specific rejections.
|
||||
|
||||
use axum_core::__composite_rejection as composite_rejection;
|
||||
use axum_core::__define_rejection as define_rejection;
|
||||
|
||||
define_rejection! {
|
||||
#[status = METHOD_NOT_ALLOWED]
|
||||
#[body = "Request method must be `GET`"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub struct MethodNotGet;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "Connection header did not include 'upgrade'"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub struct InvalidConnectionHeader;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "`Upgrade` header did not include 'websocket'"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub struct InvalidUpgradeHeader;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "`Sec-WebSocket-Version` header did not include '13'"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub struct InvalidWebSocketVersionHeader;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "`Sec-WebSocket-Key` header missing"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub struct WebSocketKeyHeaderMissing;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = UPGRADE_REQUIRED]
|
||||
#[body = "WebSocket request couldn't be upgraded since no upgrade state was present"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
///
|
||||
/// This rejection is returned if the connection cannot be upgraded for example if the
|
||||
/// request is HTTP/1.0.
|
||||
///
|
||||
/// See [MDN] for more details about connection upgrades.
|
||||
///
|
||||
/// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Upgrade
|
||||
pub struct ConnectionNotUpgradable;
|
||||
}
|
||||
|
||||
composite_rejection! {
|
||||
/// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
///
|
||||
/// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
|
||||
/// extractor can fail.
|
||||
pub enum WebSocketUpgradeRejection {
|
||||
MethodNotGet,
|
||||
InvalidConnectionHeader,
|
||||
InvalidUpgradeHeader,
|
||||
InvalidWebSocketVersionHeader,
|
||||
WebSocketKeyHeaderMissing,
|
||||
ConnectionNotUpgradable,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub mod close_code {
|
||||
//! Constants for [`CloseCode`]s.
|
||||
//!
|
||||
//! [`CloseCode`]: super::CloseCode
|
||||
|
||||
/// Indicates a normal closure, meaning that the purpose for which the connection was
|
||||
/// established has been fulfilled.
|
||||
pub const NORMAL: u16 = 1000;
|
||||
|
||||
/// Indicates that an endpoint is "going away", such as a server going down or a browser having
|
||||
/// navigated away from a page.
|
||||
pub const AWAY: u16 = 1001;
|
||||
|
||||
/// Indicates that an endpoint is terminating the connection due to a protocol error.
|
||||
pub const PROTOCOL: u16 = 1002;
|
||||
|
||||
/// Indicates that an endpoint is terminating the connection because it has received a type of
|
||||
/// data it cannot accept (e.g., an endpoint that understands only text data MAY send this if
|
||||
/// it receives a binary message).
|
||||
pub const UNSUPPORTED: u16 = 1003;
|
||||
|
||||
/// Indicates that no status code was included in a closing frame.
|
||||
pub const STATUS: u16 = 1005;
|
||||
|
||||
/// Indicates an abnormal closure.
|
||||
pub const ABNORMAL: u16 = 1006;
|
||||
|
||||
/// Indicates that an endpoint is terminating the connection because it has received data
|
||||
/// within a message that was not consistent with the type of the message (e.g., non-UTF-8
|
||||
/// RFC3629 data within a text message).
|
||||
pub const INVALID: u16 = 1007;
|
||||
|
||||
/// Indicates that an endpoint is terminating the connection because it has received a message
|
||||
/// that violates its policy. This is a generic status code that can be returned when there is
|
||||
/// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a need to
|
||||
/// hide specific details about the policy.
|
||||
pub const POLICY: u16 = 1008;
|
||||
|
||||
/// Indicates that an endpoint is terminating the connection because it has received a message
|
||||
/// that is too big for it to process.
|
||||
pub const SIZE: u16 = 1009;
|
||||
|
||||
/// Indicates that an endpoint (client) is terminating the connection because it has expected
|
||||
/// the server to negotiate one or more extension, but the server didn't return them in the
|
||||
/// response message of the WebSocket handshake. The list of extensions that are needed should
|
||||
/// be given as the reason for closing. Note that this status code is not used by the server,
|
||||
/// because it can fail the WebSocket handshake instead.
|
||||
pub const EXTENSION: u16 = 1010;
|
||||
|
||||
/// Indicates that a server is terminating the connection because it encountered an unexpected
|
||||
/// condition that prevented it from fulfilling the request.
|
||||
pub const ERROR: u16 = 1011;
|
||||
|
||||
/// Indicates that the server is restarting.
|
||||
pub const RESTART: u16 = 1012;
|
||||
|
||||
/// Indicates that the server is overloaded and the client should either connect to a different
|
||||
/// IP (when multiple targets exist), or reconnect to the same IP when a user has performed an
|
||||
/// action.
|
||||
pub const AGAIN: u16 = 1013;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::{body::Body, routing::get, Router};
|
||||
use http::{Request, Version};
|
||||
use tower::ServiceExt;
|
||||
|
||||
#[tokio::test]
|
||||
async fn rejects_http_1_0_requests() {
|
||||
let svc = get(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
|
||||
let rejection = ws.unwrap_err();
|
||||
assert!(matches!(
|
||||
rejection,
|
||||
WebSocketUpgradeRejection::ConnectionNotUpgradable(_)
|
||||
));
|
||||
std::future::ready(())
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.version(Version::HTTP_10)
|
||||
.method(Method::GET)
|
||||
.header("upgrade", "websocket")
|
||||
.header("connection", "Upgrade")
|
||||
.header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==")
|
||||
.header("sec-websocket-version", "13")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let res = svc.oneshot(req).await.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn default_on_failed_upgrade() {
|
||||
async fn handler(ws: WebSocketUpgrade) -> Response {
|
||||
ws.on_upgrade(|_| async {})
|
||||
}
|
||||
let _: Router = Router::new().route("/", get(handler));
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn on_failed_upgrade() {
|
||||
async fn handler(ws: WebSocketUpgrade) -> Response {
|
||||
ws.on_failed_upgrade(|_error: Error| println!("oops!"))
|
||||
.on_upgrade(|_| async {})
|
||||
}
|
||||
let _: Router = Router::new().route("/", get(handler));
|
||||
}
|
||||
}
|
||||
101
src/service/tcp.rs
Normal file
101
src/service/tcp.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body,
|
||||
extract::FromRequestParts,
|
||||
http::{header, request::Parts, HeaderValue, StatusCode},
|
||||
response::Response,
|
||||
};
|
||||
use hyper::upgrade::{OnUpgrade, Upgraded};
|
||||
use std::future::Future;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use crate::{domain::notary::NotaryGlobals, service::notary_service, NotaryServerError};
|
||||
|
||||
/// Custom extractor used to extract underlying TCP connection for TCP client — using the same upgrade primitives used by
|
||||
/// the WebSocket implementation where the underlying TCP connection (wrapped in an Upgraded object) only gets polled as an OnUpgrade future
|
||||
/// after the ongoing HTTP request is finished (ref: https://github.com/tokio-rs/axum/blob/a6a849bb5b96a2f641fa077fe76f70ad4d20341c/axum/src/extract/ws.rs#L122)
|
||||
///
|
||||
/// More info on the upgrade primitives: https://docs.rs/hyper/latest/hyper/upgrade/index.html
|
||||
pub struct TcpUpgrade {
|
||||
pub on_upgrade: OnUpgrade,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for TcpUpgrade
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = NotaryServerError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
let on_upgrade =
|
||||
parts
|
||||
.extensions
|
||||
.remove::<OnUpgrade>()
|
||||
.ok_or(NotaryServerError::BadProverRequest(
|
||||
"Upgrade header is not set for TCP client".to_string(),
|
||||
))?;
|
||||
|
||||
Ok(Self { on_upgrade })
|
||||
}
|
||||
}
|
||||
|
||||
impl TcpUpgrade {
|
||||
/// Utility function to complete the http upgrade protocol by
|
||||
/// (1) Return 101 switching protocol response to client to indicate the switching to TCP
|
||||
/// (2) Spawn a new thread to await on the OnUpgrade object to claim the underlying TCP connection
|
||||
pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
|
||||
where
|
||||
C: FnOnce(Upgraded) -> Fut + Send + 'static,
|
||||
Fut: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let on_upgrade = self.on_upgrade;
|
||||
tokio::spawn(async move {
|
||||
let upgraded = match on_upgrade.await {
|
||||
Ok(upgraded) => upgraded,
|
||||
Err(err) => {
|
||||
error!("Something wrong with upgrading HTTP: {:?}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
callback(upgraded).await;
|
||||
});
|
||||
|
||||
#[allow(clippy::declare_interior_mutable_const)]
|
||||
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
|
||||
#[allow(clippy::declare_interior_mutable_const)]
|
||||
const TCP: HeaderValue = HeaderValue::from_static("tcp");
|
||||
|
||||
let builder = Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(header::CONNECTION, UPGRADE)
|
||||
.header(header::UPGRADE, TCP);
|
||||
|
||||
builder.body(body::boxed(body::Empty::new())).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform notarization using the extracted tcp connection
|
||||
pub async fn tcp_notarize(
|
||||
stream: Upgraded,
|
||||
notary_globals: NotaryGlobals,
|
||||
session_id: String,
|
||||
max_transcript_size: Option<usize>,
|
||||
) {
|
||||
debug!(?session_id, "Upgraded to tcp connection");
|
||||
match notary_service(
|
||||
stream,
|
||||
¬ary_globals.notary_signing_key,
|
||||
&session_id,
|
||||
max_transcript_size,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {
|
||||
info!(?session_id, "Successful notarization using tcp!");
|
||||
}
|
||||
Err(err) => {
|
||||
error!(?session_id, "Failed notarization using tcp: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
34
src/service/websocket.rs
Normal file
34
src/service/websocket.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
use tracing::{debug, error, info};
|
||||
use ws_stream_tungstenite::WsStream;
|
||||
|
||||
use crate::{
|
||||
domain::notary::NotaryGlobals,
|
||||
service::{axum_websocket::WebSocket, notary_service},
|
||||
};
|
||||
|
||||
/// Perform notarization using the established websocket connection
|
||||
pub async fn websocket_notarize(
|
||||
socket: WebSocket,
|
||||
notary_globals: NotaryGlobals,
|
||||
session_id: String,
|
||||
max_transcript_size: Option<usize>,
|
||||
) {
|
||||
debug!(?session_id, "Upgraded to websocket connection");
|
||||
// Wrap the websocket in WsStream so that we have AsyncRead and AsyncWrite implemented
|
||||
let stream = WsStream::new(socket.into_inner());
|
||||
match notary_service(
|
||||
stream,
|
||||
¬ary_globals.notary_signing_key,
|
||||
&session_id,
|
||||
max_transcript_size,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {
|
||||
info!(?session_id, "Successful notarization using websocket!");
|
||||
}
|
||||
Err(err) => {
|
||||
error!(?session_id, "Failed notarization using websocket: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -17,7 +17,7 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn test_parse_config_file() {
|
||||
let location = "./src/config/config.yaml";
|
||||
let location = "./config.yaml";
|
||||
let config: Result<NotaryServerProperties> = parse_config_file(location);
|
||||
assert!(
|
||||
config.is_ok(),
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
use async_tungstenite::{
|
||||
tokio::connect_async_with_tls_connector_and_config, tungstenite::protocol::WebSocketConfig,
|
||||
};
|
||||
use futures::AsyncWriteExt;
|
||||
use hyper::{body::to_bytes, Body, Request, StatusCode};
|
||||
use hyper::{
|
||||
body::to_bytes,
|
||||
client::{conn::Parts, HttpConnector},
|
||||
Body, Client, Request, StatusCode,
|
||||
};
|
||||
use hyper_tls::HttpsConnector;
|
||||
use rustls::{Certificate, ClientConfig, RootCertStore};
|
||||
use std::{
|
||||
net::{IpAddr, SocketAddr},
|
||||
@@ -11,21 +19,26 @@ use tlsn_prover::{bind_prover, ProverConfig};
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
|
||||
use tracing::debug;
|
||||
use ws_stream_tungstenite::WsStream;
|
||||
|
||||
use notary_server::{
|
||||
read_pem_file, run_tcp_server, NotaryServerProperties, NotarySignatureProperties,
|
||||
read_pem_file, run_server, NotarizationProperties, NotarizationSessionRequest,
|
||||
NotarizationSessionResponse, NotaryServerProperties, NotarySignatureProperties,
|
||||
ServerProperties, TLSSignatureProperties, TracingProperties,
|
||||
};
|
||||
|
||||
const NOTARY_CA_CERT_PATH: &str = "./src/fixture/tls/rootCA.crt";
|
||||
const NOTARY_CA_CERT_BYTES: &[u8] = include_bytes!("../src/fixture/tls/rootCA.crt");
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_notarization() {
|
||||
async fn setup_config_and_server(sleep_ms: u64, port: u16) -> NotaryServerProperties {
|
||||
let notary_config = NotaryServerProperties {
|
||||
server: ServerProperties {
|
||||
name: "tlsnotaryserver.io".to_string(),
|
||||
domain: "127.0.0.1".to_string(),
|
||||
port: 7047,
|
||||
port,
|
||||
},
|
||||
notarization: NotarizationProperties {
|
||||
max_transcript_size: 1 << 14,
|
||||
},
|
||||
tls_signature: TLSSignatureProperties {
|
||||
private_key_pem_path: "./src/fixture/tls/notary.key".to_string(),
|
||||
@@ -39,24 +52,26 @@ async fn test_notarization() {
|
||||
},
|
||||
};
|
||||
|
||||
tracing_subscriber::fmt::init();
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let config = notary_config.clone();
|
||||
|
||||
// Run the the notary server
|
||||
// Run the notary server
|
||||
tokio::spawn(async move {
|
||||
run_tcp_server(&config).await.unwrap();
|
||||
run_server(&config).await.unwrap();
|
||||
});
|
||||
|
||||
// Sleep for a while to allow notary server to finish set up and start listening
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
|
||||
|
||||
// Run the prover
|
||||
run_prover(¬ary_config).await;
|
||||
notary_config
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(notary_config))]
|
||||
async fn run_prover(notary_config: &NotaryServerProperties) {
|
||||
#[tokio::test]
|
||||
async fn test_tcp_prover() {
|
||||
// Notary server configuration setup
|
||||
let notary_config = setup_config_and_server(100, 7048).await;
|
||||
|
||||
// Connect to the Notary via TLS-TCP
|
||||
let mut certificate_file_reader = read_pem_file(NOTARY_CA_CERT_PATH).await.unwrap();
|
||||
let mut certificates: Vec<Certificate> = rustls_pemfile::certs(&mut certificate_file_reader)
|
||||
@@ -75,14 +90,15 @@ async fn run_prover(notary_config: &NotaryServerProperties) {
|
||||
.with_no_client_auth();
|
||||
let notary_connector = TlsConnector::from(Arc::new(client_notary_config));
|
||||
|
||||
let notary_domain = notary_config.server.domain.clone();
|
||||
let notary_port = notary_config.server.port;
|
||||
let notary_socket = tokio::net::TcpStream::connect(SocketAddr::new(
|
||||
IpAddr::V4(notary_config.server.domain.parse().unwrap()),
|
||||
notary_config.server.port,
|
||||
IpAddr::V4(notary_domain.parse().unwrap()),
|
||||
notary_port,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let prover_address = notary_socket.local_addr().unwrap().to_string();
|
||||
let notary_tls_socket = notary_connector
|
||||
.connect(
|
||||
notary_config.server.name.as_str().try_into().unwrap(),
|
||||
@@ -91,6 +107,77 @@ async fn run_prover(notary_config: &NotaryServerProperties) {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Attach the hyper HTTP client to the notary TLS connection to send request to the /session endpoint to configure notarization and obtain session id
|
||||
let (mut request_sender, connection) = hyper::client::conn::handshake(notary_tls_socket)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Spawn the HTTP task to be run concurrently
|
||||
let connection_task = tokio::spawn(connection.without_shutdown());
|
||||
|
||||
// Build the HTTP request to configure notarization
|
||||
let payload = serde_json::to_string(&NotarizationSessionRequest {
|
||||
client_type: notary_server::ClientType::Tcp,
|
||||
max_transcript_size: Some(notary_config.notarization.max_transcript_size),
|
||||
})
|
||||
.unwrap();
|
||||
let request = Request::builder()
|
||||
.uri(format!("https://{notary_domain}:{notary_port}/session"))
|
||||
.method("POST")
|
||||
.header("Host", notary_domain.clone())
|
||||
// Need to specify application/json for axum to parse it as json
|
||||
.header("Content-Type", "application/json")
|
||||
.body(Body::from(payload))
|
||||
.unwrap();
|
||||
|
||||
debug!("Sending configuration request");
|
||||
|
||||
let response = request_sender.send_request(request).await.unwrap();
|
||||
|
||||
debug!("Sent configuration request");
|
||||
|
||||
assert!(response.status() == StatusCode::OK);
|
||||
|
||||
debug!("Response OK");
|
||||
|
||||
// Pretty printing :)
|
||||
let payload = to_bytes(response.into_body()).await.unwrap().to_vec();
|
||||
let notarization_response =
|
||||
serde_json::from_str::<NotarizationSessionResponse>(&String::from_utf8_lossy(&payload))
|
||||
.unwrap();
|
||||
|
||||
debug!("Notarization response: {:?}", notarization_response,);
|
||||
|
||||
// Send notarization request via HTTP, where the underlying TCP connection will be extracted later
|
||||
let request = Request::builder()
|
||||
.uri(format!("https://{notary_domain}:{notary_port}/notarize"))
|
||||
.method("GET")
|
||||
.header("Host", notary_domain)
|
||||
.header("Connection", "Upgrade")
|
||||
// Need to specify this upgrade header for server to extract tcp connection later
|
||||
.header("Upgrade", "TCP")
|
||||
// Need to specify the session_id so that notary server knows the right configuration to use
|
||||
// as the configuration is set in the previous HTTP call
|
||||
.header("X-Session-Id", notarization_response.session_id.clone())
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
debug!("Sending notarization request");
|
||||
|
||||
let response = request_sender.send_request(request).await.unwrap();
|
||||
|
||||
debug!("Sent notarization request");
|
||||
|
||||
assert!(response.status() == StatusCode::SWITCHING_PROTOCOLS);
|
||||
|
||||
debug!("Switched protocol OK");
|
||||
|
||||
// Claim back the TCP socket after HTTP exchange is done so that client can use it for notarization
|
||||
let Parts {
|
||||
io: notary_tls_socket,
|
||||
..
|
||||
} = connection_task.await.unwrap().unwrap();
|
||||
|
||||
// Connect to the Server
|
||||
let (client_socket, server_socket) = tokio::io::duplex(2 << 16);
|
||||
let server_task = tokio::spawn(bind_test_server(server_socket.compat()));
|
||||
@@ -100,9 +187,9 @@ async fn run_prover(notary_config: &NotaryServerProperties) {
|
||||
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
|
||||
.unwrap();
|
||||
|
||||
// Basic default prover config — use local address as the notarization session id
|
||||
// Basic default prover config — use the responded session id from notary server
|
||||
let prover_config = ProverConfig::builder()
|
||||
.id(prover_address)
|
||||
.id(notarization_response.session_id)
|
||||
.server_dns(SERVER_DOMAIN)
|
||||
.root_cert_store(root_store)
|
||||
.build()
|
||||
@@ -167,3 +254,169 @@ async fn run_prover(notary_config: &NotaryServerProperties) {
|
||||
|
||||
debug!("Done notarization!");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_websocket_prover() {
|
||||
// Notary server configuration setup
|
||||
let notary_config = setup_config_and_server(100, 7049).await;
|
||||
let notary_domain = notary_config.server.domain.clone();
|
||||
let notary_port = notary_config.server.port;
|
||||
|
||||
// Connect to the notary server via TLS-WebSocket
|
||||
// Try to avoid dealing with transport layer directly to mimic the limitation of a browser extension that uses websocket
|
||||
//
|
||||
// Establish TLS setup for connections later
|
||||
let certificate =
|
||||
tokio_native_tls::native_tls::Certificate::from_pem(NOTARY_CA_CERT_BYTES).unwrap();
|
||||
let notary_tls_connector = tokio_native_tls::native_tls::TlsConnector::builder()
|
||||
.add_root_certificate(certificate)
|
||||
.use_sni(false)
|
||||
.danger_accept_invalid_certs(true)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// Call the /session HTTP API to configure notarization and obtain session id
|
||||
let mut hyper_http_connector = HttpConnector::new();
|
||||
hyper_http_connector.enforce_http(false);
|
||||
let mut hyper_tls_connector =
|
||||
HttpsConnector::from((hyper_http_connector, notary_tls_connector.clone().into()));
|
||||
hyper_tls_connector.https_only(true);
|
||||
let https_client = Client::builder().build::<_, hyper::Body>(hyper_tls_connector);
|
||||
|
||||
// Build the HTTP request to configure notarization
|
||||
let payload = serde_json::to_string(&NotarizationSessionRequest {
|
||||
client_type: notary_server::ClientType::Websocket,
|
||||
max_transcript_size: Some(notary_config.notarization.max_transcript_size),
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let request = Request::builder()
|
||||
.uri(format!("https://{notary_domain}:{notary_port}/session"))
|
||||
.method("POST")
|
||||
.header("Host", notary_domain.clone())
|
||||
// Need to specify application/json for axum to parse it as json
|
||||
.header("Content-Type", "application/json")
|
||||
.body(Body::from(payload))
|
||||
.unwrap();
|
||||
|
||||
debug!("Sending request");
|
||||
|
||||
let response = https_client.request(request).await.unwrap();
|
||||
|
||||
debug!("Sent request");
|
||||
|
||||
assert!(response.status() == StatusCode::OK);
|
||||
|
||||
debug!("Response OK");
|
||||
|
||||
// Pretty printing :)
|
||||
let payload = to_bytes(response.into_body()).await.unwrap().to_vec();
|
||||
let notarization_response =
|
||||
serde_json::from_str::<NotarizationSessionResponse>(&String::from_utf8_lossy(&payload))
|
||||
.unwrap();
|
||||
|
||||
debug!("Notarization response: {:?}", notarization_response,);
|
||||
|
||||
// Connect to the Notary via TLS-Websocket
|
||||
//
|
||||
// Note: This will establish a new TLS-TCP connection instead of reusing the previous TCP connection
|
||||
// used in the previous HTTP POST request because we cannot claim back the tcp connection used in hyper
|
||||
// client while using its high level request function — there does not seem to have a crate that can let you
|
||||
// make a request without establishing TCP connection where you can claim the TCP connection later after making the request
|
||||
let request = http::Request::builder()
|
||||
.uri(format!("wss://{notary_domain}:{notary_port}/notarize"))
|
||||
.header("Host", notary_domain.clone())
|
||||
.header("Sec-WebSocket-Key", uuid::Uuid::new_v4().to_string())
|
||||
.header("Sec-WebSocket-Version", "13")
|
||||
.header("Connection", "Upgrade")
|
||||
.header("Upgrade", "Websocket")
|
||||
// Need to specify the session_id so that notary server knows the right configuration to use
|
||||
// as the configuration is set in the previous HTTP call
|
||||
.header("X-Session-Id", notarization_response.session_id.clone())
|
||||
.body(())
|
||||
.unwrap();
|
||||
|
||||
let (notary_ws_stream, _) = connect_async_with_tls_connector_and_config(
|
||||
request,
|
||||
Some(notary_tls_connector.into()),
|
||||
Some(WebSocketConfig::default()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wrap the socket with the adapter so that we get AsyncRead and AsyncWrite implemented
|
||||
let notary_ws_socket = WsStream::new(notary_ws_stream);
|
||||
|
||||
// Connect to the Server
|
||||
let (client_socket, server_socket) = tokio::io::duplex(2 << 16);
|
||||
let server_task = tokio::spawn(bind_test_server(server_socket.compat()));
|
||||
|
||||
let mut root_store = tls_core::anchors::RootCertStore::empty();
|
||||
root_store
|
||||
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
|
||||
.unwrap();
|
||||
|
||||
// Basic default prover config — use the responded session id from notary server
|
||||
let prover_config = ProverConfig::builder()
|
||||
.id(notarization_response.session_id)
|
||||
.server_dns(SERVER_DOMAIN)
|
||||
.root_cert_store(root_store)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// Bind the Prover to the sockets
|
||||
let (tls_connection, prover_fut, mux_fut) =
|
||||
bind_prover(prover_config, client_socket.compat(), notary_ws_socket)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Spawn the Prover and Mux tasks to be run concurrently
|
||||
tokio::spawn(mux_fut);
|
||||
let prover_task = tokio::spawn(prover_fut);
|
||||
|
||||
let (mut request_sender, connection) = hyper::client::conn::handshake(tls_connection.compat())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let connection_task = tokio::spawn(connection.without_shutdown());
|
||||
|
||||
let request = Request::builder()
|
||||
.uri(format!("https://{}/echo", SERVER_DOMAIN))
|
||||
.header("Host", SERVER_DOMAIN)
|
||||
.header("Connection", "close")
|
||||
.method("POST")
|
||||
.body(Body::from("echo"))
|
||||
.unwrap();
|
||||
|
||||
debug!("Sending request to server: {:?}", request);
|
||||
|
||||
let response = request_sender.send_request(request).await.unwrap();
|
||||
|
||||
assert!(response.status() == StatusCode::OK);
|
||||
|
||||
debug!(
|
||||
"Received response from server: {:?}",
|
||||
String::from_utf8_lossy(&to_bytes(response.into_body()).await.unwrap())
|
||||
);
|
||||
|
||||
let mut server_tls_conn = server_task.await.unwrap().unwrap();
|
||||
|
||||
// Make sure the server closes cleanly (sends close notify)
|
||||
server_tls_conn.close().await.unwrap();
|
||||
|
||||
let mut client_socket = connection_task.await.unwrap().unwrap().io.into_inner();
|
||||
|
||||
client_socket.close().await.unwrap();
|
||||
|
||||
let mut prover = prover_task.await.unwrap().unwrap();
|
||||
|
||||
let sent_len = prover.sent_transcript().data().len();
|
||||
let recv_len = prover.recv_transcript().data().len();
|
||||
|
||||
prover.add_commitment_sent(0..sent_len as u32).unwrap();
|
||||
prover.add_commitment_recv(0..recv_len as u32).unwrap();
|
||||
|
||||
_ = prover.finalize().await.unwrap();
|
||||
|
||||
debug!("Done notarization!");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user