This commit is contained in:
maceip
2024-02-01 06:24:24 +00:00
parent f58075dde6
commit 7b6cb7b5df
17 changed files with 4 additions and 2004 deletions

View File

@@ -7,47 +7,8 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
async-trait = "0.1.67"
async-tungstenite = { version = "0.22.2", features = ["tokio-native-tls"] }
axum = { version = "0.6.18", features = ["ws"] }
axum-core = "0.3.4"
axum-macros = "0.3.8"
base64 = "0.21.0"
chrono = "0.4.31"
csv = "1.3.0"
eyre = "0.6.8"
futures = "0.3"
futures-util = "0.3.28"
http = "0.2.9"
hyper = { version = "0.14", features = ["client", "http1", "server", "tcp"] }
opentelemetry = { version = "0.19" }
p256 = "0.13"
rstest = "0.18"
rustls = { version = "0.21" }
rustls-pemfile = { version = "1.0.2" }
serde = { version = "1.0.147", features = ["derive"] }
serde_json = "1.0"
serde_yaml = "0.9.21"
sha1 = "0.10"
structopt = "0.3.26"
thiserror = "1"
tlsn-verifier = { git = "https://github.com/tlsnotary/tlsn", rev = "ee17919" }
tlsn-tls-core = { git = "https://github.com/tlsnotary/tlsn", rev = "ee17919" }
tokio = { version = "1", features = ["full"] }
tokio-rustls = { version = "0.24.1" }
tokio-util = { version = "0.7", features = ["compat"] }
tower = { version = "0.4.12", features = ["make"] }
tower-http = { version = "0.4.4", features = ["cors"] }
eyre = "0.6.8"
tracing = "0.1"
tracing-opentelemetry = "0.19"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.4.1", features = ["v4", "fast-rng"] }
ws_stream_tungstenite = { version = "0.10.0", features = ["tokio_io"] }
[dev-dependencies]
# specify vendored feature to use statically linked copy of OpenSSL
hyper-tls = { version = "0.5.0", features = ["vendored"] }
tls-server-fixture = { git = "https://github.com/tlsnotary/tlsn", rev = "ee17919" }
tlsn-prover = { git = "https://github.com/tlsnotary/tlsn", rev = "ee17919" }
tokio-native-tls = { version = "0.3.1", features = ["vendored"] }
structopt = "0.3.26"
notary-server = { git = "https://github.com/tlsnotary/tlsn", rev = "ee17919"}

View File

@@ -1,66 +0,0 @@
use serde::Deserialize;
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct NotaryServerProperties {
/// Name and address of the notary server
pub server: ServerProperties,
/// Setting for notarization
pub notarization: NotarizationProperties,
/// Setting for TLS connection between prover and notary
pub tls: TLSProperties,
/// File path of private key (in PEM format) used to sign the notarization
pub notary_key: NotarySigningKeyProperties,
/// Setting for logging/tracing
pub tracing: TracingProperties,
/// Setting for authorization
pub authorization: AuthorizationProperties,
}
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct AuthorizationProperties {
/// Switch to turn on or off auth middleware
pub enabled: bool,
/// File path of the whitelist API key csv
pub whitelist_csv_path: String,
}
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct NotarizationProperties {
/// Global limit for maximum transcript size in bytes
pub max_transcript_size: usize,
}
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct ServerProperties {
/// Used for testing purpose
pub name: String,
pub host: String,
pub port: u16,
}
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct TLSProperties {
/// Flag to turn on/off TLS between prover and notary (should always be turned on unless TLS is handled by external setup e.g. reverse proxy, cloud)
pub enabled: bool,
pub private_key_pem_path: String,
pub certificate_pem_path: String,
}
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct NotarySigningKeyProperties {
pub private_key_pem_path: String,
pub public_key_pem_path: String,
}
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct TracingProperties {
/// The minimum logging level, must be either of <https://docs.rs/tracing/latest/tracing/struct.Level.html#implementations>
pub default_level: String,
}

View File

@@ -1,19 +0,0 @@
pub mod auth;
pub mod cli;
pub mod notary;
use serde::{Deserialize, Serialize};
/// Response object of the /info API
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InfoResponse {
/// Current version of notary-server
pub version: String,
/// Public key of the notary signing key
pub public_key: String,
/// Current git commit hash of notary-server
pub git_commit_hash: String,
/// Current git commit timestamp of notary-server
pub git_commit_timestamp: String,
}

View File

@@ -1,22 +0,0 @@
use serde::Deserialize;
use std::collections::HashMap;
/// Structure of each whitelisted record of the API key whitelist for authorization purpose
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct AuthorizationWhitelistRecord {
pub name: String,
pub api_key: String,
pub created_at: String,
}
/// Convert whitelist data structure from vector to hashmap using api_key as the key to speed up lookup
pub fn authorization_whitelist_vec_into_hashmap(
authorization_whitelist: Vec<AuthorizationWhitelistRecord>,
) -> HashMap<String, AuthorizationWhitelistRecord> {
let mut hashmap = HashMap::new();
authorization_whitelist.iter().for_each(|record| {
hashmap.insert(record.api_key.clone(), record.to_owned());
});
hashmap
}

View File

@@ -1,10 +0,0 @@
use structopt::StructOpt;
/// Fields loaded from the command line when launching this server.
#[derive(Clone, Debug, StructOpt)]
#[structopt(name = "Notary Server")]
pub struct CliFields {
/// Configuration file location
#[structopt(long, default_value = "./config/config.yaml")]
pub config_file: String,
}

View File

@@ -1,75 +0,0 @@
use std::{collections::HashMap, sync::Arc};
use chrono::{DateTime, Utc};
use p256::ecdsa::SigningKey;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use crate::{config::NotarizationProperties, domain::auth::AuthorizationWhitelistRecord};
/// Response object of the /session API
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NotarizationSessionResponse {
/// Unique session id that is generated by notary and shared to prover
pub session_id: String,
}
/// Request object of the /session API
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NotarizationSessionRequest {
pub client_type: ClientType,
/// Maximum transcript size in bytes
pub max_transcript_size: Option<usize>,
}
/// Request query of the /notarize API
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NotarizationRequestQuery {
/// Session id that is returned from /session API
pub session_id: String,
}
/// Types of client that the prover is using
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ClientType {
/// Client that has access to the transport layer
Tcp,
/// Client that cannot directly access transport layer, e.g. browser extension
Websocket,
}
/// Session configuration data to be stored in temporary storage
#[derive(Clone, Debug)]
pub struct SessionData {
pub max_transcript_size: Option<usize>,
pub created_at: DateTime<Utc>,
}
/// Global data that needs to be shared with the axum handlers
#[derive(Clone, Debug)]
pub struct NotaryGlobals {
pub notary_signing_key: SigningKey,
pub notarization_config: NotarizationProperties,
/// A temporary storage to store configuration data, mainly used for WebSocket client
pub store: Arc<Mutex<HashMap<String, SessionData>>>,
/// Whitelist of API keys for authorization purpose
pub authorization_whitelist: Option<Arc<HashMap<String, AuthorizationWhitelistRecord>>>,
}
impl NotaryGlobals {
pub fn new(
notary_signing_key: SigningKey,
notarization_config: NotarizationProperties,
authorization_whitelist: Option<Arc<HashMap<String, AuthorizationWhitelistRecord>>>,
) -> Self {
Self {
notary_signing_key,
notarization_config,
store: Default::default(),
authorization_whitelist,
}
}
}

View File

@@ -1,55 +0,0 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use eyre::Report;
use std::error::Error;
use tlsn_verifier::tls::{VerifierConfigBuilderError, VerifierError};
#[derive(Debug, thiserror::Error)]
pub enum NotaryServerError {
#[error(transparent)]
Unexpected(#[from] Report),
#[error("Failed to connect to prover: {0}")]
Connection(String),
#[error("Error occurred during notarization: {0}")]
Notarization(Box<dyn Error + Send + 'static>),
#[error("Invalid request from prover: {0}")]
BadProverRequest(String),
#[error("Unauthorized request from prover: {0}")]
UnauthorizedProverRequest(String),
}
impl From<VerifierError> for NotaryServerError {
fn from(error: VerifierError) -> Self {
Self::Notarization(Box::new(error))
}
}
impl From<VerifierConfigBuilderError> for NotaryServerError {
fn from(error: VerifierConfigBuilderError) -> Self {
Self::Notarization(Box::new(error))
}
}
/// Trait implementation to convert this error into an axum http response
impl IntoResponse for NotaryServerError {
fn into_response(self) -> Response {
match self {
bad_request_error @ NotaryServerError::BadProverRequest(_) => {
(StatusCode::BAD_REQUEST, bad_request_error.to_string()).into_response()
}
unauthorized_request_error @ NotaryServerError::UnauthorizedProverRequest(_) => (
StatusCode::UNAUTHORIZED,
unauthorized_request_error.to_string(),
)
.into_response(),
_ => (
StatusCode::INTERNAL_SERVER_ERROR,
"Something wrong happened.",
)
.into_response(),
}
}
}

View File

@@ -1,21 +0,0 @@
mod config;
mod domain;
mod error;
mod middleware;
mod server;
mod server_tracing;
mod service;
mod util;
pub use config::{
AuthorizationProperties, NotarizationProperties, NotaryServerProperties,
NotarySigningKeyProperties, ServerProperties, TLSProperties, TracingProperties,
};
pub use domain::{
cli::CliFields,
notary::{ClientType, NotarizationSessionRequest, NotarizationSessionResponse},
};
pub use error::NotaryServerError;
pub use server::{read_pem_file, run_server};
pub use server_tracing::init_tracing;
pub use util::parse_config_file;

View File

@@ -2,7 +2,7 @@ use eyre::{eyre, Result};
use structopt::StructOpt;
use tracing::debug;
use sgx_notary_server::{
use notary_server::{
init_tracing, parse_config_file, run_server, CliFields, NotaryServerError,
NotaryServerProperties,
};

View File

@@ -1,104 +0,0 @@
use async_trait::async_trait;
use axum::http::{header, request::Parts};
use axum_core::extract::{FromRef, FromRequestParts};
use std::collections::HashMap;
use tracing::{error, trace};
use crate::{
domain::{auth::AuthorizationWhitelistRecord, notary::NotaryGlobals},
NotaryServerError,
};
/// Auth middleware to prevent DOS
pub struct AuthorizationMiddleware;
#[async_trait]
impl<S> FromRequestParts<S> for AuthorizationMiddleware
where
NotaryGlobals: FromRef<S>,
S: Send + Sync,
{
type Rejection = NotaryServerError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let notary_globals = NotaryGlobals::from_ref(state);
let Some(whitelist) = notary_globals.authorization_whitelist else {
trace!("Skipping authorization as whitelist is not set.");
return Ok(Self);
};
let auth_header = parts
.headers
.get(header::AUTHORIZATION)
.and_then(|value| std::str::from_utf8(value.as_bytes()).ok());
match auth_header {
Some(auth_header) => {
if api_key_is_valid(auth_header, &whitelist) {
trace!("Request authorized.");
Ok(Self)
} else {
let err_msg = "Invalid API key.".to_string();
error!(err_msg);
Err(NotaryServerError::UnauthorizedProverRequest(err_msg))
}
}
None => {
let err_msg = "Missing API key.".to_string();
error!(err_msg);
Err(NotaryServerError::UnauthorizedProverRequest(err_msg))
}
}
}
}
/// Helper function to check if an API key is in whitelist
fn api_key_is_valid(
api_key: &str,
whitelist: &HashMap<String, AuthorizationWhitelistRecord>,
) -> bool {
whitelist.get(api_key).is_some()
}
#[cfg(test)]
mod test {
use super::{api_key_is_valid, HashMap};
use crate::domain::auth::{
authorization_whitelist_vec_into_hashmap, AuthorizationWhitelistRecord,
};
use std::sync::Arc;
fn get_whitelist_fixture() -> HashMap<String, AuthorizationWhitelistRecord> {
authorization_whitelist_vec_into_hashmap(vec![
AuthorizationWhitelistRecord {
name: "test-name-0".to_string(),
api_key: "test-api-key-0".to_string(),
created_at: "2023-10-18T07:38:53Z".to_string(),
},
AuthorizationWhitelistRecord {
name: "test-name-1".to_string(),
api_key: "test-api-key-1".to_string(),
created_at: "2023-10-11T07:38:53Z".to_string(),
},
AuthorizationWhitelistRecord {
name: "test-name-2".to_string(),
api_key: "test-api-key-2".to_string(),
created_at: "2022-10-11T07:38:53Z".to_string(),
},
])
}
#[test]
fn test_api_key_is_present() {
let whitelist = get_whitelist_fixture();
assert!(api_key_is_valid("test-api-key-0", &Arc::new(whitelist)));
}
#[test]
fn test_api_key_is_absent() {
let whitelist = get_whitelist_fixture();
assert_eq!(
api_key_is_valid("test-api-keY-0", &Arc::new(whitelist)),
false
);
}
}

View File

@@ -1,273 +0,0 @@
use axum::{
http::{Request, StatusCode},
middleware::from_extractor_with_state,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use eyre::{ensure, eyre, Result};
use futures_util::future::poll_fn;
use hyper::server::{
accept::Accept,
conn::{AddrIncoming, Http},
};
use p256::{ecdsa::SigningKey, pkcs8::DecodePrivateKey};
use rustls::{Certificate, PrivateKey, ServerConfig};
use std::{
fs::File as StdFile,
io::BufReader,
net::{IpAddr, SocketAddr},
pin::Pin,
sync::Arc,
};
use tower_http::cors::CorsLayer;
use tokio::{fs::File, net::TcpListener};
use tokio_rustls::TlsAcceptor;
use tower::MakeService;
use tracing::{debug, error, info};
use crate::{
config::{NotaryServerProperties, NotarySigningKeyProperties},
domain::{
auth::{authorization_whitelist_vec_into_hashmap, AuthorizationWhitelistRecord},
notary::NotaryGlobals,
InfoResponse,
},
error::NotaryServerError,
middleware::AuthorizationMiddleware,
service::{initialize, upgrade_protocol},
util::parse_csv_file,
};
/// Start a TCP server (with or without TLS) to accept notarization request for both TCP and WebSocket clients
#[tracing::instrument(skip(config))]
pub async fn run_server(config: &NotaryServerProperties) -> Result<(), NotaryServerError> {
// Load the private key for notarized transcript signing
let notary_signing_key = load_notary_signing_key(&config.notary_key).await?;
// Build TLS acceptor if it is turned on
let tls_acceptor = if !config.tls.enabled {
debug!("Skipping TLS setup as it is turned off.");
None
} else {
let (tls_private_key, tls_certificates) = load_tls_key_and_cert(
&config.tls.private_key_pem_path,
&config.tls.certificate_pem_path,
)
.await?;
let mut server_config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(tls_certificates, tls_private_key)
.map_err(|err| eyre!("Failed to instantiate notary server tls config: {err}"))?;
// Set the http protocols we support
server_config.alpn_protocols = vec![b"http/1.1".to_vec()];
let tls_config = Arc::new(server_config);
Some(TlsAcceptor::from(tls_config))
};
// Load the authorization whitelist csv if it is turned on
let authorization_whitelist = if !config.authorization.enabled {
debug!("Skipping authorization as it is turned off.");
None
} else {
// Load the csv
let whitelist_csv = parse_csv_file::<AuthorizationWhitelistRecord>(
&config.authorization.whitelist_csv_path,
)
.map_err(|err| eyre!("Failed to parse authorization whitelist csv: {:?}", err))?;
// Convert the whitelist record into hashmap for faster lookup
Some(authorization_whitelist_vec_into_hashmap(whitelist_csv))
};
let notary_address = SocketAddr::new(
IpAddr::V4(config.server.host.parse().map_err(|err| {
eyre!("Failed to parse notary host address from server config: {err}")
})?),
config.server.port,
);
let listener = TcpListener::bind(notary_address)
.await
.map_err(|err| eyre!("Failed to bind server address to tcp listener: {err}"))?;
let mut listener = AddrIncoming::from_listener(listener)
.map_err(|err| eyre!("Failed to build hyper tcp listener: {err}"))?;
info!("Listening for TCP traffic at {}", notary_address);
let protocol = Arc::new(Http::new());
let notary_globals = NotaryGlobals::new(
notary_signing_key,
config.notarization.clone(),
// Use Arc to prevent cloning the whitelist for every request
authorization_whitelist.map(Arc::new),
);
// Parameters needed for the info endpoint
let public_key = std::fs::read_to_string(&config.notary_key.public_key_pem_path)
.map_err(|err| eyre!("Failed to load notary public signing key for notarization: {err}"))?;
let version = env!("CARGO_PKG_VERSION").to_string();
let git_commit_hash = env!("GIT_COMMIT_HASH").to_string();
let git_commit_timestamp = env!("GIT_COMMIT_TIMESTAMP").to_string();
let router = Router::new()
.route(
"/healthcheck",
get(|| async move { (StatusCode::OK, "Ok").into_response() }),
)
.route(
"/info",
get(|| async move {
(
StatusCode::OK,
Json(InfoResponse {
version,
public_key,
git_commit_hash,
git_commit_timestamp,
}),
)
.into_response()
}),
)
.route("/session", post(initialize))
// Not applying auth middleware to /notarize endpoint for now as we can rely on our
// short-lived session id generated from /session endpoint, as it is not possible
// to use header for API key for websocket /notarize endpoint due to browser restriction
// ref: https://stackoverflow.com/a/4361358; And putting it in url query param
// seems to be more insecured: https://stackoverflow.com/questions/5517281/place-api-key-in-headers-or-url
.route_layer(from_extractor_with_state::<
AuthorizationMiddleware,
NotaryGlobals,
>(notary_globals.clone()))
.route("/notarize", get(upgrade_protocol))
.layer(CorsLayer::permissive())
.with_state(notary_globals);
let mut app = router.into_make_service();
loop {
// Poll and await for any incoming connection, ensure that all operations inside are infallible to prevent bringing down the server
let (prover_address, stream) =
match poll_fn(|cx| Pin::new(&mut listener).poll_accept(cx)).await {
Some(Ok(connection)) => (connection.remote_addr(), connection),
Some(Err(err)) => {
error!("{}", NotaryServerError::Connection(err.to_string()));
continue;
}
None => unreachable!("The poll_accept method should never return None"),
};
debug!(?prover_address, "Received a prover's TCP connection");
let tls_acceptor = tls_acceptor.clone();
let protocol = protocol.clone();
let service = MakeService::<_, Request<hyper::Body>>::make_service(&mut app, &stream);
// Spawn a new async task to handle the new connection
tokio::spawn(async move {
// When TLS is enabled
if let Some(acceptor) = tls_acceptor {
match acceptor.accept(stream).await {
Ok(stream) => {
info!(
?prover_address,
"Accepted prover's TLS-secured TCP connection",
);
// Serve different requests using the same hyper protocol and axum router
let _ = protocol
// Can unwrap because it's infallible
.serve_connection(stream, service.await.unwrap())
// use with_upgrades to upgrade connection to websocket for websocket clients
// and to extract tcp connection for tcp clients
.with_upgrades()
.await;
}
Err(err) => {
error!(
?prover_address,
"{}",
NotaryServerError::Connection(err.to_string())
);
}
}
} else {
// When TLS is disabled
info!(?prover_address, "Accepted prover's TCP connection",);
// Serve different requests using the same hyper protocol and axum router
let _ = protocol
// Can unwrap because it's infallible
.serve_connection(stream, service.await.unwrap())
// use with_upgrades to upgrade connection to websocket for websocket clients
// and to extract tcp connection for tcp clients
.with_upgrades()
.await;
}
});
}
}
/// Temporary function to load notary signing key from static file
async fn load_notary_signing_key(config: &NotarySigningKeyProperties) -> Result<SigningKey> {
debug!("Loading notary server's signing key");
let notary_signing_key = SigningKey::read_pkcs8_pem_file(&config.private_key_pem_path)
.map_err(|err| eyre!("Failed to load notary signing key for notarization: {err}"))?;
debug!("Successfully loaded notary server's signing key!");
Ok(notary_signing_key)
}
/// Read a PEM-formatted file and return its buffer reader
pub async fn read_pem_file(file_path: &str) -> Result<BufReader<StdFile>> {
let key_file = File::open(file_path).await?.into_std().await;
Ok(BufReader::new(key_file))
}
/// Load notary tls private key and cert from static files
async fn load_tls_key_and_cert(
private_key_pem_path: &str,
certificate_pem_path: &str,
) -> Result<(PrivateKey, Vec<Certificate>)> {
debug!("Loading notary server's tls private key and certificate");
let mut private_key_file_reader = read_pem_file(private_key_pem_path).await?;
let mut private_keys = rustls_pemfile::pkcs8_private_keys(&mut private_key_file_reader)?;
ensure!(
private_keys.len() == 1,
"More than 1 key found in the tls private key pem file"
);
let private_key = PrivateKey(private_keys.remove(0));
let mut certificate_file_reader = read_pem_file(certificate_pem_path).await?;
let certificates = rustls_pemfile::certs(&mut certificate_file_reader)?
.into_iter()
.map(Certificate)
.collect();
debug!("Successfully loaded notary server's tls private key and certificate!");
Ok((private_key, certificates))
}
#[cfg(test)]
mod test {
use super::*;
#[tokio::test]
async fn test_load_notary_key_and_cert() {
let private_key_pem_path = "./fixture/tls/notary.key";
let certificate_pem_path = "./fixture/tls/notary.crt";
let result: Result<(PrivateKey, Vec<Certificate>)> =
load_tls_key_and_cert(private_key_pem_path, certificate_pem_path).await;
assert!(result.is_ok(), "Could not load tls private key and cert");
}
#[tokio::test]
async fn test_load_notary_signing_key() {
let config = NotarySigningKeyProperties {
private_key_pem_path: "./fixture/notary/notary.key".to_string(),
public_key_pem_path: "./fixture/notary/notary.pub".to_string(),
};
let result: Result<SigningKey> = load_notary_signing_key(&config).await;
assert!(result.is_ok(), "Could not load notary private key");
}
}

View File

@@ -1,37 +0,0 @@
use eyre::Result;
use opentelemetry::{
global,
sdk::{export::trace::stdout, propagation::TraceContextPropagator},
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Registry};
use crate::config::NotaryServerProperties;
pub fn init_tracing(config: &NotaryServerProperties) -> Result<()> {
// Create a new OpenTelemetry pipeline
let tracer = stdout::new_pipeline().install_simple();
// Create a tracing layer with the configured tracer
let tracing_layer = tracing_opentelemetry::layer().with_tracer(tracer);
// Set the log level
let env_filter_layer = EnvFilter::new(&config.tracing.default_level);
// Format the log
let format_layer = tracing_subscriber::fmt::layer()
// Use a more compact, abbreviated log format
.compact()
.with_thread_ids(true)
.with_thread_names(true);
// Set up context propagation
global::set_text_map_propagator(TraceContextPropagator::default());
Registry::default()
.with(tracing_layer)
.with(env_filter_layer)
.with(format_layer)
.try_init()?;
Ok(())
}

View File

@@ -1,178 +0,0 @@
pub mod axum_websocket;
pub mod tcp;
pub mod websocket;
use async_trait::async_trait;
use axum::{
extract::{rejection::JsonRejection, FromRequestParts, Query, State},
http::{header, request::Parts, StatusCode},
response::{IntoResponse, Json, Response},
};
use axum_macros::debug_handler;
use chrono::Utc;
use p256::ecdsa::{Signature, SigningKey};
use tlsn_verifier::tls::{Verifier, VerifierConfig};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::compat::TokioAsyncReadCompatExt;
use tracing::{debug, error, info, trace};
use uuid::Uuid;
use crate::{
domain::notary::{
NotarizationRequestQuery, NotarizationSessionRequest, NotarizationSessionResponse,
NotaryGlobals, SessionData,
},
error::NotaryServerError,
service::{
axum_websocket::{header_eq, WebSocketUpgrade},
tcp::{tcp_notarize, TcpUpgrade},
websocket::websocket_notarize,
},
};
/// A wrapper enum to facilitate extracting TCP connection for either WebSocket or TCP clients,
/// so that we can use a single endpoint and handler for notarization for both types of clients
pub enum ProtocolUpgrade {
Tcp(TcpUpgrade),
Ws(WebSocketUpgrade),
}
#[async_trait]
impl<S> FromRequestParts<S> for ProtocolUpgrade
where
S: Send + Sync,
{
type Rejection = NotaryServerError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// Extract tcp connection for websocket client
if header_eq(&parts.headers, header::UPGRADE, "websocket") {
let extractor = WebSocketUpgrade::from_request_parts(parts, state)
.await
.map_err(|err| NotaryServerError::BadProverRequest(err.to_string()))?;
return Ok(Self::Ws(extractor));
// Extract tcp connection for tcp client
} else if header_eq(&parts.headers, header::UPGRADE, "tcp") {
let extractor = TcpUpgrade::from_request_parts(parts, state)
.await
.map_err(|err| NotaryServerError::BadProverRequest(err.to_string()))?;
return Ok(Self::Tcp(extractor));
} else {
return Err(NotaryServerError::BadProverRequest(
"Upgrade header is not set for client".to_string(),
));
}
}
}
/// Handler to upgrade protocol from http to either websocket or underlying tcp depending on the type of client
/// the session_id parameter is also extracted here to fetch the configuration parameters
/// that have been submitted in the previous request to /session made by the same client
pub async fn upgrade_protocol(
protocol_upgrade: ProtocolUpgrade,
State(notary_globals): State<NotaryGlobals>,
Query(params): Query<NotarizationRequestQuery>,
) -> Response {
info!("Received upgrade protocol request");
let session_id = params.session_id;
// Fetch the configuration data from the store using the session_id
// This also removes the configuration data from the store as each session_id can only be used once
let max_transcript_size = match notary_globals.store.lock().await.remove(&session_id) {
Some(data) => data.max_transcript_size,
None => {
let err_msg = format!("Session id {} does not exist", session_id);
error!(err_msg);
return NotaryServerError::BadProverRequest(err_msg).into_response();
}
};
// This completes the HTTP Upgrade request and returns a successful response to the client, meanwhile initiating the websocket or tcp connection
match protocol_upgrade {
ProtocolUpgrade::Ws(ws) => ws.on_upgrade(move |socket| {
websocket_notarize(socket, notary_globals, session_id, max_transcript_size)
}),
ProtocolUpgrade::Tcp(tcp) => tcp.on_upgrade(move |stream| {
tcp_notarize(stream, notary_globals, session_id, max_transcript_size)
}),
}
}
/// Handler to initialize and configure notarization for both TCP and WebSocket clients
#[debug_handler(state = NotaryGlobals)]
pub async fn initialize(
State(notary_globals): State<NotaryGlobals>,
payload: Result<Json<NotarizationSessionRequest>, JsonRejection>,
) -> impl IntoResponse {
info!(
?payload,
"Received request for initializing a notarization session"
);
// Parse the body payload
let payload = match payload {
Ok(payload) => payload,
Err(err) => {
error!("Malformed payload submitted for initializing notarization: {err}");
return NotaryServerError::BadProverRequest(err.to_string()).into_response();
}
};
// Ensure that the max_transcript_size submitted is not larger than the global max limit configured in notary server
if payload.max_transcript_size > Some(notary_globals.notarization_config.max_transcript_size) {
error!(
"Max transcript size requested {:?} exceeds the maximum threshold {:?}",
payload.max_transcript_size, notary_globals.notarization_config.max_transcript_size
);
return NotaryServerError::BadProverRequest(
"Max transcript size requested exceeds the maximum threshold".to_string(),
)
.into_response();
}
let prover_session_id = Uuid::new_v4().to_string();
// Store the configuration data in a temporary store
notary_globals.store.lock().await.insert(
prover_session_id.clone(),
SessionData {
max_transcript_size: payload.max_transcript_size,
created_at: Utc::now(),
},
);
trace!("Latest store state: {:?}", notary_globals.store);
// Return the session id in the response to the client
(
StatusCode::OK,
Json(NotarizationSessionResponse {
session_id: prover_session_id,
}),
)
.into_response()
}
/// Run the notarization
pub async fn notary_service<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
socket: T,
signing_key: &SigningKey,
session_id: &str,
max_transcript_size: Option<usize>,
) -> Result<(), NotaryServerError> {
debug!(?session_id, "Starting notarization...");
let mut config_builder = VerifierConfig::builder();
config_builder = config_builder.id(session_id);
if let Some(max_transcript_size) = max_transcript_size {
config_builder = config_builder.max_transcript_size(max_transcript_size);
}
let config = config_builder.build()?;
Verifier::new(config)
.notarize::<_, Signature>(socket.compat(), signing_key)
.await?;
Ok(())
}

View File

@@ -1,914 +0,0 @@
//! The following code is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.6.19/axum/src/extract/ws.rs
//! where we swapped out tokio_tungstenite (https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/)
//! with async_tungstenite (https://docs.rs/async-tungstenite/latest/async_tungstenite/) so that we can use
//! ws_stream_tungstenite (https://docs.rs/ws_stream_tungstenite/latest/ws_stream_tungstenite/index.html)
//! to get AsyncRead and AsyncWrite implemented for the WebSocket. Any other modification is commented with the prefix "NOTARY_MODIFICATION:"
//!
//! The code is under the following license:
//!
//! Copyright (c) 2019 Axum Contributors
//!
//! Permission is hereby granted, free of charge, to any
//! person obtaining a copy of this software and associated
//! documentation files (the "Software"), to deal in the
//! Software without restriction, including without
//! limitation the rights to use, copy, modify, merge,
//! publish, distribute, sublicense, and/or sell copies of
//! the Software, and to permit persons to whom the Software
//! is furnished to do so, subject to the following
//! conditions:
//!
//! The above copyright notice and this permission notice
//! shall be included in all copies or substantial portions
//! of the Software.
//!
//! THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
//! ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
//! TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
//! PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
//! SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
//! CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
//! OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
//! IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
//! DEALINGS IN THE SOFTWARE.
//!
//!
//! Handle WebSocket connections.
//!
//! # Example
//!
//! ```
//! use axum::{
//! extract::ws::{WebSocketUpgrade, WebSocket},
//! routing::get,
//! response::{IntoResponse, Response},
//! Router,
//! };
//!
//! let app = Router::new().route("/ws", get(handler));
//!
//! async fn handler(ws: WebSocketUpgrade) -> Response {
//! ws.on_upgrade(handle_socket)
//! }
//!
//! async fn handle_socket(mut socket: WebSocket) {
//! while let Some(msg) = socket.recv().await {
//! let msg = if let Ok(msg) = msg {
//! msg
//! } else {
//! // client disconnected
//! return;
//! };
//!
//! if socket.send(msg).await.is_err() {
//! // client disconnected
//! return;
//! }
//! }
//! }
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! # Passing data and/or state to an `on_upgrade` callback
//!
//! ```
//! use axum::{
//! extract::{ws::{WebSocketUpgrade, WebSocket}, State},
//! response::Response,
//! routing::get,
//! Router,
//! };
//!
//! #[derive(Clone)]
//! struct AppState {
//! // ...
//! }
//!
//! async fn handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
//! ws.on_upgrade(|socket| handle_socket(socket, state))
//! }
//!
//! async fn handle_socket(socket: WebSocket, state: AppState) {
//! // ...
//! }
//!
//! let app = Router::new()
//! .route("/ws", get(handler))
//! .with_state(AppState { /* ... */ });
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! # Read and write concurrently
//!
//! If you need to read and write concurrently from a [`WebSocket`] you can use
//! [`StreamExt::split`]:
//!
//! ```rust,no_run
//! use axum::{Error, extract::ws::{WebSocket, Message}};
//! use futures_util::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}};
//!
//! async fn handle_socket(mut socket: WebSocket) {
//! let (mut sender, mut receiver) = socket.split();
//!
//! tokio::spawn(write(sender));
//! tokio::spawn(read(receiver));
//! }
//!
//! async fn read(receiver: SplitStream<WebSocket>) {
//! // ...
//! }
//!
//! async fn write(sender: SplitSink<WebSocket, Message>) {
//! // ...
//! }
//! ```
//!
//! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
#![allow(unused)]
use self::rejection::*;
use async_trait::async_trait;
use async_tungstenite::{
tokio::TokioAdapter,
tungstenite::{
self as ts,
protocol::{self, WebSocketConfig},
},
WebSocketStream,
};
use axum::{
body::{self, Bytes},
extract::FromRequestParts,
response::Response,
Error,
};
use futures_util::{
sink::{Sink, SinkExt},
stream::{Stream, StreamExt},
};
use http::{
header::{self, HeaderMap, HeaderName, HeaderValue},
request::Parts,
Method, StatusCode,
};
use hyper::upgrade::{OnUpgrade, Upgraded};
use sha1::{Digest, Sha1};
use std::{
borrow::Cow,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tracing::error;
/// Extractor for establishing WebSocket connections.
///
/// Note: This extractor requires the request method to be `GET` so it should
/// always be used with [`get`](crate::routing::get). Requests with other methods will be
/// rejected.
///
/// See the [module docs](self) for an example.
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpdgrade> {
config: WebSocketConfig,
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
protocol: Option<HeaderValue>,
sec_websocket_key: HeaderValue,
on_upgrade: OnUpgrade,
on_failed_upgrade: F,
sec_websocket_protocol: Option<HeaderValue>,
}
impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("WebSocketUpgrade")
.field("config", &self.config)
.field("protocol", &self.protocol)
.field("sec_websocket_key", &self.sec_websocket_key)
.field("sec_websocket_protocol", &self.sec_websocket_protocol)
.finish_non_exhaustive()
}
}
impl<F> WebSocketUpgrade<F> {
/// Set the size of the internal message send queue.
pub fn max_send_queue(mut self, max: usize) -> Self {
self.config.max_send_queue = Some(max);
self
}
/// Set the maximum message size (defaults to 64 megabytes)
pub fn max_message_size(mut self, max: usize) -> Self {
self.config.max_message_size = Some(max);
self
}
/// Set the maximum frame size (defaults to 16 megabytes)
pub fn max_frame_size(mut self, max: usize) -> Self {
self.config.max_frame_size = Some(max);
self
}
/// Allow server to accept unmasked frames (defaults to false)
pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
self.config.accept_unmasked_frames = accept;
self
}
/// Set the known protocols.
///
/// If the protocol name specified by `Sec-WebSocket-Protocol` header
/// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
/// return the protocol name.
///
/// The protocols should be listed in decreasing order of preference: if the client offers
/// multiple protocols that the server could support, the server will pick the first one in
/// this list.
///
/// # Examples
///
/// ```
/// use axum::{
/// extract::ws::{WebSocketUpgrade, WebSocket},
/// routing::get,
/// response::{IntoResponse, Response},
/// Router,
/// };
///
/// let app = Router::new().route("/ws", get(handler));
///
/// async fn handler(ws: WebSocketUpgrade) -> Response {
/// ws.protocols(["graphql-ws", "graphql-transport-ws"])
/// .on_upgrade(|socket| async {
/// // ...
/// })
/// }
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
pub fn protocols<I>(mut self, protocols: I) -> Self
where
I: IntoIterator,
I::Item: Into<Cow<'static, str>>,
{
if let Some(req_protocols) = self
.sec_websocket_protocol
.as_ref()
.and_then(|p| p.to_str().ok())
{
self.protocol = protocols
.into_iter()
// FIXME: This will often allocate a new `String` and so is less efficient than it
// could be. But that can't be fixed without breaking changes to the public API.
.map(Into::into)
.find(|protocol| {
req_protocols
.split(',')
.any(|req_protocol| req_protocol.trim() == protocol)
})
.map(|protocol| match protocol {
Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
Cow::Borrowed(s) => HeaderValue::from_static(s),
});
}
self
}
/// Provide a callback to call if upgrading the connection fails.
///
/// The connection upgrade is performed in a background task. If that fails this callback
/// will be called.
///
/// By default any errors will be silently ignored.
///
/// # Example
///
/// ```
/// use axum::{
/// extract::{WebSocketUpgrade},
/// response::Response,
/// };
///
/// async fn handler(ws: WebSocketUpgrade) -> Response {
/// ws.on_failed_upgrade(|error| {
/// report_error(error);
/// })
/// .on_upgrade(|socket| async { /* ... */ })
/// }
/// #
/// # fn report_error(_: axum::Error) {}
/// ```
pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
where
C: OnFailedUpdgrade,
{
WebSocketUpgrade {
config: self.config,
protocol: self.protocol,
sec_websocket_key: self.sec_websocket_key,
on_upgrade: self.on_upgrade,
on_failed_upgrade: callback,
sec_websocket_protocol: self.sec_websocket_protocol,
}
}
/// Finalize upgrading the connection and call the provided callback with
/// the stream.
#[must_use = "to setup the WebSocket connection, this response must be returned"]
pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
where
C: FnOnce(WebSocket) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
F: OnFailedUpdgrade,
{
let on_upgrade = self.on_upgrade;
let config = self.config;
let on_failed_upgrade = self.on_failed_upgrade;
let protocol = self.protocol.clone();
tokio::spawn(async move {
let upgraded = match on_upgrade.await {
Ok(upgraded) => upgraded,
Err(err) => {
error!("Something wrong with on_upgrade: {:?}", err);
on_failed_upgrade.call(Error::new(err));
return;
}
};
let socket = WebSocketStream::from_raw_socket(
// NOTARY_MODIFICATION: Need to use TokioAdapter to wrap Upgraded which doesn't implement futures crate's AsyncRead and AsyncWrite
TokioAdapter::new(upgraded),
protocol::Role::Server,
Some(config),
)
.await;
let socket = WebSocket {
inner: socket,
protocol,
};
callback(socket).await;
});
#[allow(clippy::declare_interior_mutable_const)]
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
#[allow(clippy::declare_interior_mutable_const)]
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, UPGRADE)
.header(header::UPGRADE, WEBSOCKET)
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(self.sec_websocket_key.as_bytes()),
);
if let Some(protocol) = self.protocol {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
builder.body(body::boxed(body::Empty::new())).unwrap()
}
}
/// What to do when a connection upgrade fails.
///
/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
pub trait OnFailedUpdgrade: Send + 'static {
/// Call the callback.
fn call(self, error: Error);
}
impl<F> OnFailedUpdgrade for F
where
F: FnOnce(Error) + Send + 'static,
{
fn call(self, error: Error) {
self(error)
}
}
/// The default `OnFailedUpdgrade` used by `WebSocketUpgrade`.
///
/// It simply ignores the error.
#[non_exhaustive]
#[derive(Debug)]
pub struct DefaultOnFailedUpdgrade;
impl OnFailedUpdgrade for DefaultOnFailedUpdgrade {
#[inline]
fn call(self, _error: Error) {}
}
#[async_trait]
impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpdgrade>
where
S: Send + Sync,
{
type Rejection = WebSocketUpgradeRejection;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if parts.method != Method::GET {
return Err(MethodNotGet.into());
}
if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
return Err(InvalidConnectionHeader.into());
}
if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
return Err(InvalidUpgradeHeader.into());
}
if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
return Err(InvalidWebSocketVersionHeader.into());
}
let sec_websocket_key = parts
.headers
.get(header::SEC_WEBSOCKET_KEY)
.ok_or(WebSocketKeyHeaderMissing)?
.clone();
let on_upgrade = parts
.extensions
.remove::<OnUpgrade>()
.ok_or(ConnectionNotUpgradable)?;
let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
Ok(Self {
config: Default::default(),
protocol: None,
sec_websocket_key,
on_upgrade,
sec_websocket_protocol,
on_failed_upgrade: DefaultOnFailedUpdgrade,
})
}
}
pub fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
if let Some(header) = headers.get(&key) {
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
} else {
false
}
}
fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
let header = if let Some(header) = headers.get(&key) {
header
} else {
return false;
};
if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
header.to_ascii_lowercase().contains(value)
} else {
false
}
}
/// A stream of WebSocket messages.
///
/// See [the module level documentation](self) for more details.
#[derive(Debug)]
pub struct WebSocket {
inner: WebSocketStream<TokioAdapter<Upgraded>>,
protocol: Option<HeaderValue>,
}
impl WebSocket {
/// Consume `self` and get the inner [`async_tungstenite::WebSocketStream`].
pub fn into_inner(self) -> WebSocketStream<TokioAdapter<Upgraded>> {
self.inner
}
/// Receive another message.
///
/// Returns `None` if the stream has closed.
pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
self.next().await
}
/// Send a message.
pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
self.inner
.send(msg.into_tungstenite())
.await
.map_err(Error::new)
}
/// Gracefully close this WebSocket.
pub async fn close(mut self) -> Result<(), Error> {
self.inner.close(None).await.map_err(Error::new)
}
/// Return the selected WebSocket subprotocol, if one has been chosen.
pub fn protocol(&self) -> Option<&HeaderValue> {
self.protocol.as_ref()
}
}
impl Stream for WebSocket {
type Item = Result<Message, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
Some(Ok(msg)) => {
if let Some(msg) = Message::from_tungstenite(msg) {
return Poll::Ready(Some(Ok(msg)));
}
}
Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))),
None => return Poll::Ready(None),
}
}
}
}
impl Sink<Message> for WebSocket {
type Error = Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new)
}
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
Pin::new(&mut self.inner)
.start_send(item.into_tungstenite())
.map_err(Error::new)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new)
}
}
/// Status code used to indicate why an endpoint is closing the WebSocket connection.
pub type CloseCode = u16;
/// A struct representing the close command.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame<'t> {
/// The reason as a code.
pub code: CloseCode,
/// The reason as text string.
pub reason: Cow<'t, str>,
}
/// A WebSocket message.
//
// This code comes from https://github.com/snapview/tungstenite-rs/blob/master/src/protocol/message.rs and is under following license:
// Copyright (c) 2017 Alexey Galakhov
// Copyright (c) 2016 Jason Housley
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message {
/// A text WebSocket message
Text(String),
/// A binary WebSocket message
Binary(Vec<u8>),
/// A ping message with the specified payload
///
/// The payload here must have a length less than 125 bytes.
///
/// Ping messages will be automatically responded to by the server, so you do not have to worry
/// about dealing with them yourself.
Ping(Vec<u8>),
/// A pong message with the specified payload
///
/// The payload here must have a length less than 125 bytes.
///
/// Pong messages will be automatically sent to the client if a ping message is received, so
/// you do not have to worry about constructing them yourself unless you want to implement a
/// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
Pong(Vec<u8>),
/// A close message with the optional close frame.
Close(Option<CloseFrame<'static>>),
}
impl Message {
fn into_tungstenite(self) -> ts::Message {
match self {
Self::Text(text) => ts::Message::Text(text),
Self::Binary(binary) => ts::Message::Binary(binary),
Self::Ping(ping) => ts::Message::Ping(ping),
Self::Pong(pong) => ts::Message::Pong(pong),
Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
code: ts::protocol::frame::coding::CloseCode::from(close.code),
reason: close.reason,
})),
Self::Close(None) => ts::Message::Close(None),
}
}
fn from_tungstenite(message: ts::Message) -> Option<Self> {
match message {
ts::Message::Text(text) => Some(Self::Text(text)),
ts::Message::Binary(binary) => Some(Self::Binary(binary)),
ts::Message::Ping(ping) => Some(Self::Ping(ping)),
ts::Message::Pong(pong) => Some(Self::Pong(pong)),
ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
code: close.code.into(),
reason: close.reason,
}))),
ts::Message::Close(None) => Some(Self::Close(None)),
// we can ignore `Frame` frames as recommended by the tungstenite maintainers
// https://github.com/snapview/tungstenite-rs/issues/268
ts::Message::Frame(_) => None,
}
}
/// Consume the WebSocket and return it as binary data.
pub fn into_data(self) -> Vec<u8> {
match self {
Self::Text(string) => string.into_bytes(),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
Self::Close(None) => Vec::new(),
Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
}
}
/// Attempt to consume the WebSocket message and convert it to a String.
pub fn into_text(self) -> Result<String, Error> {
match self {
Self::Text(string) => Ok(string),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data)
.map_err(|err| err.utf8_error())
.map_err(Error::new)?),
Self::Close(None) => Ok(String::new()),
Self::Close(Some(frame)) => Ok(frame.reason.into_owned()),
}
}
/// Attempt to get a &str from the WebSocket message,
/// this will try to convert binary data to utf8.
pub fn to_text(&self) -> Result<&str, Error> {
match *self {
Self::Text(ref string) => Ok(string),
Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
Ok(std::str::from_utf8(data).map_err(Error::new)?)
}
Self::Close(None) => Ok(""),
Self::Close(Some(ref frame)) => Ok(&frame.reason),
}
}
}
impl From<String> for Message {
fn from(string: String) -> Self {
Message::Text(string)
}
}
impl<'s> From<&'s str> for Message {
fn from(string: &'s str) -> Self {
Message::Text(string.into())
}
}
impl<'b> From<&'b [u8]> for Message {
fn from(data: &'b [u8]) -> Self {
Message::Binary(data.into())
}
}
impl From<Vec<u8>> for Message {
fn from(data: Vec<u8>) -> Self {
Message::Binary(data)
}
}
impl From<Message> for Vec<u8> {
fn from(msg: Message) -> Self {
msg.into_data()
}
}
fn sign(key: &[u8]) -> HeaderValue {
use base64::engine::Engine as _;
let mut sha1 = Sha1::default();
sha1.update(key);
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
}
pub mod rejection {
//! WebSocket specific rejections.
use axum_core::{
__composite_rejection as composite_rejection, __define_rejection as define_rejection,
};
define_rejection! {
#[status = METHOD_NOT_ALLOWED]
#[body = "Request method must be `GET`"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct MethodNotGet;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Connection header did not include 'upgrade'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidConnectionHeader;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Upgrade` header did not include 'websocket'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidUpgradeHeader;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Sec-WebSocket-Version` header did not include '13'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidWebSocketVersionHeader;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Sec-WebSocket-Key` header missing"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct WebSocketKeyHeaderMissing;
}
define_rejection! {
#[status = UPGRADE_REQUIRED]
#[body = "WebSocket request couldn't be upgraded since no upgrade state was present"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
///
/// This rejection is returned if the connection cannot be upgraded for example if the
/// request is HTTP/1.0.
///
/// See [MDN] for more details about connection upgrades.
///
/// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Upgrade
pub struct ConnectionNotUpgradable;
}
composite_rejection! {
/// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
///
/// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
/// extractor can fail.
pub enum WebSocketUpgradeRejection {
MethodNotGet,
InvalidConnectionHeader,
InvalidUpgradeHeader,
InvalidWebSocketVersionHeader,
WebSocketKeyHeaderMissing,
ConnectionNotUpgradable,
}
}
}
pub mod close_code {
//! Constants for [`CloseCode`]s.
//!
//! [`CloseCode`]: super::CloseCode
/// Indicates a normal closure, meaning that the purpose for which the connection was
/// established has been fulfilled.
pub const NORMAL: u16 = 1000;
/// Indicates that an endpoint is "going away", such as a server going down or a browser having
/// navigated away from a page.
pub const AWAY: u16 = 1001;
/// Indicates that an endpoint is terminating the connection due to a protocol error.
pub const PROTOCOL: u16 = 1002;
/// Indicates that an endpoint is terminating the connection because it has received a type of
/// data it cannot accept (e.g., an endpoint that understands only text data MAY send this if
/// it receives a binary message).
pub const UNSUPPORTED: u16 = 1003;
/// Indicates that no status code was included in a closing frame.
pub const STATUS: u16 = 1005;
/// Indicates an abnormal closure.
pub const ABNORMAL: u16 = 1006;
/// Indicates that an endpoint is terminating the connection because it has received data
/// within a message that was not consistent with the type of the message (e.g., non-UTF-8
/// RFC3629 data within a text message).
pub const INVALID: u16 = 1007;
/// Indicates that an endpoint is terminating the connection because it has received a message
/// that violates its policy. This is a generic status code that can be returned when there is
/// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a need to
/// hide specific details about the policy.
pub const POLICY: u16 = 1008;
/// Indicates that an endpoint is terminating the connection because it has received a message
/// that is too big for it to process.
pub const SIZE: u16 = 1009;
/// Indicates that an endpoint (client) is terminating the connection because it has expected
/// the server to negotiate one or more extension, but the server didn't return them in the
/// response message of the WebSocket handshake. The list of extensions that are needed should
/// be given as the reason for closing. Note that this status code is not used by the server,
/// because it can fail the WebSocket handshake instead.
pub const EXTENSION: u16 = 1010;
/// Indicates that a server is terminating the connection because it encountered an unexpected
/// condition that prevented it from fulfilling the request.
pub const ERROR: u16 = 1011;
/// Indicates that the server is restarting.
pub const RESTART: u16 = 1012;
/// Indicates that the server is overloaded and the client should either connect to a different
/// IP (when multiple targets exist), or reconnect to the same IP when a user has performed an
/// action.
pub const AGAIN: u16 = 1013;
}
#[cfg(test)]
mod tests {
use super::*;
use axum::{body::Body, routing::get, Router};
use http::{Request, Version};
use tower::ServiceExt;
#[tokio::test]
async fn rejects_http_1_0_requests() {
let svc = get(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
let rejection = ws.unwrap_err();
assert!(matches!(
rejection,
WebSocketUpgradeRejection::ConnectionNotUpgradable(_)
));
std::future::ready(())
});
let req = Request::builder()
.version(Version::HTTP_10)
.method(Method::GET)
.header("upgrade", "websocket")
.header("connection", "Upgrade")
.header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==")
.header("sec-websocket-version", "13")
.body(Body::empty())
.unwrap();
let res = svc.oneshot(req).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
#[allow(dead_code)]
fn default_on_failed_upgrade() {
async fn handler(ws: WebSocketUpgrade) -> Response {
ws.on_upgrade(|_| async {})
}
let _: Router = Router::new().route("/", get(handler));
}
#[allow(dead_code)]
fn on_failed_upgrade() {
async fn handler(ws: WebSocketUpgrade) -> Response {
ws.on_failed_upgrade(|_error: Error| println!("oops!"))
.on_upgrade(|_| async {})
}
let _: Router = Router::new().route("/", get(handler));
}
}

View File

@@ -1,101 +0,0 @@
use async_trait::async_trait;
use axum::{
body,
extract::FromRequestParts,
http::{header, request::Parts, HeaderValue, StatusCode},
response::Response,
};
use hyper::upgrade::{OnUpgrade, Upgraded};
use std::future::Future;
use tracing::{debug, error, info};
use crate::{domain::notary::NotaryGlobals, service::notary_service, NotaryServerError};
/// Custom extractor used to extract underlying TCP connection for TCP client — using the same upgrade primitives used by
/// the WebSocket implementation where the underlying TCP connection (wrapped in an Upgraded object) only gets polled as an OnUpgrade future
/// after the ongoing HTTP request is finished (ref: https://github.com/tokio-rs/axum/blob/a6a849bb5b96a2f641fa077fe76f70ad4d20341c/axum/src/extract/ws.rs#L122)
///
/// More info on the upgrade primitives: https://docs.rs/hyper/latest/hyper/upgrade/index.html
pub struct TcpUpgrade {
pub on_upgrade: OnUpgrade,
}
#[async_trait]
impl<S> FromRequestParts<S> for TcpUpgrade
where
S: Send + Sync,
{
type Rejection = NotaryServerError;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let on_upgrade =
parts
.extensions
.remove::<OnUpgrade>()
.ok_or(NotaryServerError::BadProverRequest(
"Upgrade header is not set for TCP client".to_string(),
))?;
Ok(Self { on_upgrade })
}
}
impl TcpUpgrade {
/// Utility function to complete the http upgrade protocol by
/// (1) Return 101 switching protocol response to client to indicate the switching to TCP
/// (2) Spawn a new thread to await on the OnUpgrade object to claim the underlying TCP connection
pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
where
C: FnOnce(Upgraded) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let on_upgrade = self.on_upgrade;
tokio::spawn(async move {
let upgraded = match on_upgrade.await {
Ok(upgraded) => upgraded,
Err(err) => {
error!("Something wrong with upgrading HTTP: {:?}", err);
return;
}
};
callback(upgraded).await;
});
#[allow(clippy::declare_interior_mutable_const)]
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
#[allow(clippy::declare_interior_mutable_const)]
const TCP: HeaderValue = HeaderValue::from_static("tcp");
let builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, UPGRADE)
.header(header::UPGRADE, TCP);
builder.body(body::boxed(body::Empty::new())).unwrap()
}
}
/// Perform notarization using the extracted tcp connection
pub async fn tcp_notarize(
stream: Upgraded,
notary_globals: NotaryGlobals,
session_id: String,
max_transcript_size: Option<usize>,
) {
debug!(?session_id, "Upgraded to tcp connection");
match notary_service(
stream,
&notary_globals.notary_signing_key,
&session_id,
max_transcript_size,
)
.await
{
Ok(_) => {
info!(?session_id, "Successful notarization using tcp!");
}
Err(err) => {
error!(?session_id, "Failed notarization using tcp: {err}");
}
}
}

View File

@@ -1,34 +0,0 @@
use tracing::{debug, error, info};
use ws_stream_tungstenite::WsStream;
use crate::{
domain::notary::NotaryGlobals,
service::{axum_websocket::WebSocket, notary_service},
};
/// Perform notarization using the established websocket connection
pub async fn websocket_notarize(
socket: WebSocket,
notary_globals: NotaryGlobals,
session_id: String,
max_transcript_size: Option<usize>,
) {
debug!(?session_id, "Upgraded to websocket connection");
// Wrap the websocket in WsStream so that we have AsyncRead and AsyncWrite implemented
let stream = WsStream::new(socket.into_inner());
match notary_service(
stream,
&notary_globals.notary_signing_key,
&session_id,
max_transcript_size,
)
.await
{
Ok(_) => {
info!(?session_id, "Successful notarization using websocket!");
}
Err(err) => {
error!(?session_id, "Failed notarization using websocket: {err}");
}
}
}

View File

@@ -1,52 +0,0 @@
use eyre::Result;
use serde::de::DeserializeOwned;
/// Parse a yaml configuration file into a struct
pub fn parse_config_file<T: DeserializeOwned>(location: &str) -> Result<T> {
let file = std::fs::File::open(location)?;
let config: T = serde_yaml::from_reader(file)?;
Ok(config)
}
/// Parse a csv file into a vec of structs
pub fn parse_csv_file<T: DeserializeOwned>(location: &str) -> Result<Vec<T>> {
let file = std::fs::File::open(location)?;
let mut reader = csv::Reader::from_reader(file);
let mut table: Vec<T> = Vec::new();
for result in reader.deserialize() {
let record: T = result?;
table.push(record);
}
Ok(table)
}
#[cfg(test)]
mod test {
use crate::{
config::NotaryServerProperties, domain::auth::AuthorizationWhitelistRecord,
util::parse_csv_file,
};
use super::{parse_config_file, Result};
#[test]
fn test_parse_config_file() {
let location = "./config/config.yaml";
let config: Result<NotaryServerProperties> = parse_config_file(location);
assert!(
config.is_ok(),
"Could not open file or read the file's values."
);
}
#[test]
fn test_parse_csv_file() {
let location = "./fixture/auth/whitelist.csv";
let table: Result<Vec<AuthorizationWhitelistRecord>> = parse_csv_file(location);
assert!(
table.is_ok(),
"Could not open csv or read the csv's values."
);
}
}