mirror of
https://github.com/maceip/sgx-tlsn-notary-server.git
synced 2026-01-09 20:57:59 -05:00
lib
This commit is contained in:
45
Cargo.toml
45
Cargo.toml
@@ -7,47 +7,8 @@ edition = "2021"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.67"
|
||||
async-tungstenite = { version = "0.22.2", features = ["tokio-native-tls"] }
|
||||
axum = { version = "0.6.18", features = ["ws"] }
|
||||
axum-core = "0.3.4"
|
||||
axum-macros = "0.3.8"
|
||||
base64 = "0.21.0"
|
||||
chrono = "0.4.31"
|
||||
csv = "1.3.0"
|
||||
eyre = "0.6.8"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3.28"
|
||||
http = "0.2.9"
|
||||
hyper = { version = "0.14", features = ["client", "http1", "server", "tcp"] }
|
||||
opentelemetry = { version = "0.19" }
|
||||
p256 = "0.13"
|
||||
rstest = "0.18"
|
||||
rustls = { version = "0.21" }
|
||||
rustls-pemfile = { version = "1.0.2" }
|
||||
serde = { version = "1.0.147", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
serde_yaml = "0.9.21"
|
||||
sha1 = "0.10"
|
||||
structopt = "0.3.26"
|
||||
thiserror = "1"
|
||||
tlsn-verifier = { git = "https://github.com/tlsnotary/tlsn", rev = "ee17919" }
|
||||
tlsn-tls-core = { git = "https://github.com/tlsnotary/tlsn", rev = "ee17919" }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tokio-rustls = { version = "0.24.1" }
|
||||
tokio-util = { version = "0.7", features = ["compat"] }
|
||||
tower = { version = "0.4.12", features = ["make"] }
|
||||
tower-http = { version = "0.4.4", features = ["cors"] }
|
||||
eyre = "0.6.8"
|
||||
tracing = "0.1"
|
||||
tracing-opentelemetry = "0.19"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
uuid = { version = "1.4.1", features = ["v4", "fast-rng"] }
|
||||
ws_stream_tungstenite = { version = "0.10.0", features = ["tokio_io"] }
|
||||
|
||||
[dev-dependencies]
|
||||
# specify vendored feature to use statically linked copy of OpenSSL
|
||||
hyper-tls = { version = "0.5.0", features = ["vendored"] }
|
||||
tls-server-fixture = { git = "https://github.com/tlsnotary/tlsn", rev = "ee17919" }
|
||||
tlsn-prover = { git = "https://github.com/tlsnotary/tlsn", rev = "ee17919" }
|
||||
tokio-native-tls = { version = "0.3.1", features = ["vendored"] }
|
||||
|
||||
structopt = "0.3.26"
|
||||
notary-server = { git = "https://github.com/tlsnotary/tlsn", rev = "ee17919"}
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct NotaryServerProperties {
|
||||
/// Name and address of the notary server
|
||||
pub server: ServerProperties,
|
||||
/// Setting for notarization
|
||||
pub notarization: NotarizationProperties,
|
||||
/// Setting for TLS connection between prover and notary
|
||||
pub tls: TLSProperties,
|
||||
/// File path of private key (in PEM format) used to sign the notarization
|
||||
pub notary_key: NotarySigningKeyProperties,
|
||||
/// Setting for logging/tracing
|
||||
pub tracing: TracingProperties,
|
||||
/// Setting for authorization
|
||||
pub authorization: AuthorizationProperties,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct AuthorizationProperties {
|
||||
/// Switch to turn on or off auth middleware
|
||||
pub enabled: bool,
|
||||
/// File path of the whitelist API key csv
|
||||
pub whitelist_csv_path: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct NotarizationProperties {
|
||||
/// Global limit for maximum transcript size in bytes
|
||||
pub max_transcript_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct ServerProperties {
|
||||
/// Used for testing purpose
|
||||
pub name: String,
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct TLSProperties {
|
||||
/// Flag to turn on/off TLS between prover and notary (should always be turned on unless TLS is handled by external setup e.g. reverse proxy, cloud)
|
||||
pub enabled: bool,
|
||||
pub private_key_pem_path: String,
|
||||
pub certificate_pem_path: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct NotarySigningKeyProperties {
|
||||
pub private_key_pem_path: String,
|
||||
pub public_key_pem_path: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct TracingProperties {
|
||||
/// The minimum logging level, must be either of <https://docs.rs/tracing/latest/tracing/struct.Level.html#implementations>
|
||||
pub default_level: String,
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
pub mod auth;
|
||||
pub mod cli;
|
||||
pub mod notary;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Response object of the /info API
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InfoResponse {
|
||||
/// Current version of notary-server
|
||||
pub version: String,
|
||||
/// Public key of the notary signing key
|
||||
pub public_key: String,
|
||||
/// Current git commit hash of notary-server
|
||||
pub git_commit_hash: String,
|
||||
/// Current git commit timestamp of notary-server
|
||||
pub git_commit_timestamp: String,
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Structure of each whitelisted record of the API key whitelist for authorization purpose
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
pub struct AuthorizationWhitelistRecord {
|
||||
pub name: String,
|
||||
pub api_key: String,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// Convert whitelist data structure from vector to hashmap using api_key as the key to speed up lookup
|
||||
pub fn authorization_whitelist_vec_into_hashmap(
|
||||
authorization_whitelist: Vec<AuthorizationWhitelistRecord>,
|
||||
) -> HashMap<String, AuthorizationWhitelistRecord> {
|
||||
let mut hashmap = HashMap::new();
|
||||
authorization_whitelist.iter().for_each(|record| {
|
||||
hashmap.insert(record.api_key.clone(), record.to_owned());
|
||||
});
|
||||
hashmap
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
use structopt::StructOpt;
|
||||
|
||||
/// Fields loaded from the command line when launching this server.
|
||||
#[derive(Clone, Debug, StructOpt)]
|
||||
#[structopt(name = "Notary Server")]
|
||||
pub struct CliFields {
|
||||
/// Configuration file location
|
||||
#[structopt(long, default_value = "./config/config.yaml")]
|
||||
pub config_file: String,
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use p256::ecdsa::SigningKey;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::{config::NotarizationProperties, domain::auth::AuthorizationWhitelistRecord};
|
||||
|
||||
/// Response object of the /session API
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct NotarizationSessionResponse {
|
||||
/// Unique session id that is generated by notary and shared to prover
|
||||
pub session_id: String,
|
||||
}
|
||||
|
||||
/// Request object of the /session API
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct NotarizationSessionRequest {
|
||||
pub client_type: ClientType,
|
||||
/// Maximum transcript size in bytes
|
||||
pub max_transcript_size: Option<usize>,
|
||||
}
|
||||
|
||||
/// Request query of the /notarize API
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct NotarizationRequestQuery {
|
||||
/// Session id that is returned from /session API
|
||||
pub session_id: String,
|
||||
}
|
||||
|
||||
/// Types of client that the prover is using
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum ClientType {
|
||||
/// Client that has access to the transport layer
|
||||
Tcp,
|
||||
/// Client that cannot directly access transport layer, e.g. browser extension
|
||||
Websocket,
|
||||
}
|
||||
|
||||
/// Session configuration data to be stored in temporary storage
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SessionData {
|
||||
pub max_transcript_size: Option<usize>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Global data that needs to be shared with the axum handlers
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NotaryGlobals {
|
||||
pub notary_signing_key: SigningKey,
|
||||
pub notarization_config: NotarizationProperties,
|
||||
/// A temporary storage to store configuration data, mainly used for WebSocket client
|
||||
pub store: Arc<Mutex<HashMap<String, SessionData>>>,
|
||||
/// Whitelist of API keys for authorization purpose
|
||||
pub authorization_whitelist: Option<Arc<HashMap<String, AuthorizationWhitelistRecord>>>,
|
||||
}
|
||||
|
||||
impl NotaryGlobals {
|
||||
pub fn new(
|
||||
notary_signing_key: SigningKey,
|
||||
notarization_config: NotarizationProperties,
|
||||
authorization_whitelist: Option<Arc<HashMap<String, AuthorizationWhitelistRecord>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
notary_signing_key,
|
||||
notarization_config,
|
||||
store: Default::default(),
|
||||
authorization_whitelist,
|
||||
}
|
||||
}
|
||||
}
|
||||
55
src/error.rs
55
src/error.rs
@@ -1,55 +0,0 @@
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use eyre::Report;
|
||||
use std::error::Error;
|
||||
|
||||
use tlsn_verifier::tls::{VerifierConfigBuilderError, VerifierError};
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum NotaryServerError {
|
||||
#[error(transparent)]
|
||||
Unexpected(#[from] Report),
|
||||
#[error("Failed to connect to prover: {0}")]
|
||||
Connection(String),
|
||||
#[error("Error occurred during notarization: {0}")]
|
||||
Notarization(Box<dyn Error + Send + 'static>),
|
||||
#[error("Invalid request from prover: {0}")]
|
||||
BadProverRequest(String),
|
||||
#[error("Unauthorized request from prover: {0}")]
|
||||
UnauthorizedProverRequest(String),
|
||||
}
|
||||
|
||||
impl From<VerifierError> for NotaryServerError {
|
||||
fn from(error: VerifierError) -> Self {
|
||||
Self::Notarization(Box::new(error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<VerifierConfigBuilderError> for NotaryServerError {
|
||||
fn from(error: VerifierConfigBuilderError) -> Self {
|
||||
Self::Notarization(Box::new(error))
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait implementation to convert this error into an axum http response
|
||||
impl IntoResponse for NotaryServerError {
|
||||
fn into_response(self) -> Response {
|
||||
match self {
|
||||
bad_request_error @ NotaryServerError::BadProverRequest(_) => {
|
||||
(StatusCode::BAD_REQUEST, bad_request_error.to_string()).into_response()
|
||||
}
|
||||
unauthorized_request_error @ NotaryServerError::UnauthorizedProverRequest(_) => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
unauthorized_request_error.to_string(),
|
||||
)
|
||||
.into_response(),
|
||||
_ => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Something wrong happened.",
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
21
src/lib.rs
21
src/lib.rs
@@ -1,21 +0,0 @@
|
||||
mod config;
|
||||
mod domain;
|
||||
mod error;
|
||||
mod middleware;
|
||||
mod server;
|
||||
mod server_tracing;
|
||||
mod service;
|
||||
mod util;
|
||||
|
||||
pub use config::{
|
||||
AuthorizationProperties, NotarizationProperties, NotaryServerProperties,
|
||||
NotarySigningKeyProperties, ServerProperties, TLSProperties, TracingProperties,
|
||||
};
|
||||
pub use domain::{
|
||||
cli::CliFields,
|
||||
notary::{ClientType, NotarizationSessionRequest, NotarizationSessionResponse},
|
||||
};
|
||||
pub use error::NotaryServerError;
|
||||
pub use server::{read_pem_file, run_server};
|
||||
pub use server_tracing::init_tracing;
|
||||
pub use util::parse_config_file;
|
||||
@@ -2,7 +2,7 @@ use eyre::{eyre, Result};
|
||||
use structopt::StructOpt;
|
||||
use tracing::debug;
|
||||
|
||||
use sgx_notary_server::{
|
||||
use notary_server::{
|
||||
init_tracing, parse_config_file, run_server, CliFields, NotaryServerError,
|
||||
NotaryServerProperties,
|
||||
};
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
use async_trait::async_trait;
|
||||
use axum::http::{header, request::Parts};
|
||||
use axum_core::extract::{FromRef, FromRequestParts};
|
||||
use std::collections::HashMap;
|
||||
use tracing::{error, trace};
|
||||
|
||||
use crate::{
|
||||
domain::{auth::AuthorizationWhitelistRecord, notary::NotaryGlobals},
|
||||
NotaryServerError,
|
||||
};
|
||||
|
||||
/// Auth middleware to prevent DOS
|
||||
pub struct AuthorizationMiddleware;
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for AuthorizationMiddleware
|
||||
where
|
||||
NotaryGlobals: FromRef<S>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = NotaryServerError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let notary_globals = NotaryGlobals::from_ref(state);
|
||||
let Some(whitelist) = notary_globals.authorization_whitelist else {
|
||||
trace!("Skipping authorization as whitelist is not set.");
|
||||
return Ok(Self);
|
||||
};
|
||||
let auth_header = parts
|
||||
.headers
|
||||
.get(header::AUTHORIZATION)
|
||||
.and_then(|value| std::str::from_utf8(value.as_bytes()).ok());
|
||||
|
||||
match auth_header {
|
||||
Some(auth_header) => {
|
||||
if api_key_is_valid(auth_header, &whitelist) {
|
||||
trace!("Request authorized.");
|
||||
Ok(Self)
|
||||
} else {
|
||||
let err_msg = "Invalid API key.".to_string();
|
||||
error!(err_msg);
|
||||
Err(NotaryServerError::UnauthorizedProverRequest(err_msg))
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let err_msg = "Missing API key.".to_string();
|
||||
error!(err_msg);
|
||||
Err(NotaryServerError::UnauthorizedProverRequest(err_msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to check if an API key is in whitelist
|
||||
fn api_key_is_valid(
|
||||
api_key: &str,
|
||||
whitelist: &HashMap<String, AuthorizationWhitelistRecord>,
|
||||
) -> bool {
|
||||
whitelist.get(api_key).is_some()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::{api_key_is_valid, HashMap};
|
||||
use crate::domain::auth::{
|
||||
authorization_whitelist_vec_into_hashmap, AuthorizationWhitelistRecord,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
fn get_whitelist_fixture() -> HashMap<String, AuthorizationWhitelistRecord> {
|
||||
authorization_whitelist_vec_into_hashmap(vec![
|
||||
AuthorizationWhitelistRecord {
|
||||
name: "test-name-0".to_string(),
|
||||
api_key: "test-api-key-0".to_string(),
|
||||
created_at: "2023-10-18T07:38:53Z".to_string(),
|
||||
},
|
||||
AuthorizationWhitelistRecord {
|
||||
name: "test-name-1".to_string(),
|
||||
api_key: "test-api-key-1".to_string(),
|
||||
created_at: "2023-10-11T07:38:53Z".to_string(),
|
||||
},
|
||||
AuthorizationWhitelistRecord {
|
||||
name: "test-name-2".to_string(),
|
||||
api_key: "test-api-key-2".to_string(),
|
||||
created_at: "2022-10-11T07:38:53Z".to_string(),
|
||||
},
|
||||
])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_api_key_is_present() {
|
||||
let whitelist = get_whitelist_fixture();
|
||||
assert!(api_key_is_valid("test-api-key-0", &Arc::new(whitelist)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_api_key_is_absent() {
|
||||
let whitelist = get_whitelist_fixture();
|
||||
assert_eq!(
|
||||
api_key_is_valid("test-api-keY-0", &Arc::new(whitelist)),
|
||||
false
|
||||
);
|
||||
}
|
||||
}
|
||||
273
src/server.rs
273
src/server.rs
@@ -1,273 +0,0 @@
|
||||
use axum::{
|
||||
http::{Request, StatusCode},
|
||||
middleware::from_extractor_with_state,
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use eyre::{ensure, eyre, Result};
|
||||
use futures_util::future::poll_fn;
|
||||
use hyper::server::{
|
||||
accept::Accept,
|
||||
conn::{AddrIncoming, Http},
|
||||
};
|
||||
use p256::{ecdsa::SigningKey, pkcs8::DecodePrivateKey};
|
||||
use rustls::{Certificate, PrivateKey, ServerConfig};
|
||||
use std::{
|
||||
fs::File as StdFile,
|
||||
io::BufReader,
|
||||
net::{IpAddr, SocketAddr},
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
};
|
||||
use tower_http::cors::CorsLayer;
|
||||
|
||||
use tokio::{fs::File, net::TcpListener};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tower::MakeService;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use crate::{
|
||||
config::{NotaryServerProperties, NotarySigningKeyProperties},
|
||||
domain::{
|
||||
auth::{authorization_whitelist_vec_into_hashmap, AuthorizationWhitelistRecord},
|
||||
notary::NotaryGlobals,
|
||||
InfoResponse,
|
||||
},
|
||||
error::NotaryServerError,
|
||||
middleware::AuthorizationMiddleware,
|
||||
service::{initialize, upgrade_protocol},
|
||||
util::parse_csv_file,
|
||||
};
|
||||
|
||||
/// Start a TCP server (with or without TLS) to accept notarization request for both TCP and WebSocket clients
|
||||
#[tracing::instrument(skip(config))]
|
||||
pub async fn run_server(config: &NotaryServerProperties) -> Result<(), NotaryServerError> {
|
||||
// Load the private key for notarized transcript signing
|
||||
let notary_signing_key = load_notary_signing_key(&config.notary_key).await?;
|
||||
// Build TLS acceptor if it is turned on
|
||||
let tls_acceptor = if !config.tls.enabled {
|
||||
debug!("Skipping TLS setup as it is turned off.");
|
||||
None
|
||||
} else {
|
||||
let (tls_private_key, tls_certificates) = load_tls_key_and_cert(
|
||||
&config.tls.private_key_pem_path,
|
||||
&config.tls.certificate_pem_path,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut server_config = ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(tls_certificates, tls_private_key)
|
||||
.map_err(|err| eyre!("Failed to instantiate notary server tls config: {err}"))?;
|
||||
|
||||
// Set the http protocols we support
|
||||
server_config.alpn_protocols = vec![b"http/1.1".to_vec()];
|
||||
let tls_config = Arc::new(server_config);
|
||||
Some(TlsAcceptor::from(tls_config))
|
||||
};
|
||||
|
||||
// Load the authorization whitelist csv if it is turned on
|
||||
let authorization_whitelist = if !config.authorization.enabled {
|
||||
debug!("Skipping authorization as it is turned off.");
|
||||
None
|
||||
} else {
|
||||
// Load the csv
|
||||
let whitelist_csv = parse_csv_file::<AuthorizationWhitelistRecord>(
|
||||
&config.authorization.whitelist_csv_path,
|
||||
)
|
||||
.map_err(|err| eyre!("Failed to parse authorization whitelist csv: {:?}", err))?;
|
||||
// Convert the whitelist record into hashmap for faster lookup
|
||||
Some(authorization_whitelist_vec_into_hashmap(whitelist_csv))
|
||||
};
|
||||
|
||||
let notary_address = SocketAddr::new(
|
||||
IpAddr::V4(config.server.host.parse().map_err(|err| {
|
||||
eyre!("Failed to parse notary host address from server config: {err}")
|
||||
})?),
|
||||
config.server.port,
|
||||
);
|
||||
let listener = TcpListener::bind(notary_address)
|
||||
.await
|
||||
.map_err(|err| eyre!("Failed to bind server address to tcp listener: {err}"))?;
|
||||
let mut listener = AddrIncoming::from_listener(listener)
|
||||
.map_err(|err| eyre!("Failed to build hyper tcp listener: {err}"))?;
|
||||
|
||||
info!("Listening for TCP traffic at {}", notary_address);
|
||||
|
||||
let protocol = Arc::new(Http::new());
|
||||
let notary_globals = NotaryGlobals::new(
|
||||
notary_signing_key,
|
||||
config.notarization.clone(),
|
||||
// Use Arc to prevent cloning the whitelist for every request
|
||||
authorization_whitelist.map(Arc::new),
|
||||
);
|
||||
|
||||
// Parameters needed for the info endpoint
|
||||
let public_key = std::fs::read_to_string(&config.notary_key.public_key_pem_path)
|
||||
.map_err(|err| eyre!("Failed to load notary public signing key for notarization: {err}"))?;
|
||||
let version = env!("CARGO_PKG_VERSION").to_string();
|
||||
let git_commit_hash = env!("GIT_COMMIT_HASH").to_string();
|
||||
let git_commit_timestamp = env!("GIT_COMMIT_TIMESTAMP").to_string();
|
||||
|
||||
let router = Router::new()
|
||||
.route(
|
||||
"/healthcheck",
|
||||
get(|| async move { (StatusCode::OK, "Ok").into_response() }),
|
||||
)
|
||||
.route(
|
||||
"/info",
|
||||
get(|| async move {
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(InfoResponse {
|
||||
version,
|
||||
public_key,
|
||||
git_commit_hash,
|
||||
git_commit_timestamp,
|
||||
}),
|
||||
)
|
||||
.into_response()
|
||||
}),
|
||||
)
|
||||
.route("/session", post(initialize))
|
||||
// Not applying auth middleware to /notarize endpoint for now as we can rely on our
|
||||
// short-lived session id generated from /session endpoint, as it is not possible
|
||||
// to use header for API key for websocket /notarize endpoint due to browser restriction
|
||||
// ref: https://stackoverflow.com/a/4361358; And putting it in url query param
|
||||
// seems to be more insecured: https://stackoverflow.com/questions/5517281/place-api-key-in-headers-or-url
|
||||
.route_layer(from_extractor_with_state::<
|
||||
AuthorizationMiddleware,
|
||||
NotaryGlobals,
|
||||
>(notary_globals.clone()))
|
||||
.route("/notarize", get(upgrade_protocol))
|
||||
.layer(CorsLayer::permissive())
|
||||
.with_state(notary_globals);
|
||||
let mut app = router.into_make_service();
|
||||
|
||||
loop {
|
||||
// Poll and await for any incoming connection, ensure that all operations inside are infallible to prevent bringing down the server
|
||||
let (prover_address, stream) =
|
||||
match poll_fn(|cx| Pin::new(&mut listener).poll_accept(cx)).await {
|
||||
Some(Ok(connection)) => (connection.remote_addr(), connection),
|
||||
Some(Err(err)) => {
|
||||
error!("{}", NotaryServerError::Connection(err.to_string()));
|
||||
continue;
|
||||
}
|
||||
None => unreachable!("The poll_accept method should never return None"),
|
||||
};
|
||||
debug!(?prover_address, "Received a prover's TCP connection");
|
||||
|
||||
let tls_acceptor = tls_acceptor.clone();
|
||||
let protocol = protocol.clone();
|
||||
let service = MakeService::<_, Request<hyper::Body>>::make_service(&mut app, &stream);
|
||||
|
||||
// Spawn a new async task to handle the new connection
|
||||
tokio::spawn(async move {
|
||||
// When TLS is enabled
|
||||
if let Some(acceptor) = tls_acceptor {
|
||||
match acceptor.accept(stream).await {
|
||||
Ok(stream) => {
|
||||
info!(
|
||||
?prover_address,
|
||||
"Accepted prover's TLS-secured TCP connection",
|
||||
);
|
||||
// Serve different requests using the same hyper protocol and axum router
|
||||
let _ = protocol
|
||||
// Can unwrap because it's infallible
|
||||
.serve_connection(stream, service.await.unwrap())
|
||||
// use with_upgrades to upgrade connection to websocket for websocket clients
|
||||
// and to extract tcp connection for tcp clients
|
||||
.with_upgrades()
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
error!(
|
||||
?prover_address,
|
||||
"{}",
|
||||
NotaryServerError::Connection(err.to_string())
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// When TLS is disabled
|
||||
info!(?prover_address, "Accepted prover's TCP connection",);
|
||||
// Serve different requests using the same hyper protocol and axum router
|
||||
let _ = protocol
|
||||
// Can unwrap because it's infallible
|
||||
.serve_connection(stream, service.await.unwrap())
|
||||
// use with_upgrades to upgrade connection to websocket for websocket clients
|
||||
// and to extract tcp connection for tcp clients
|
||||
.with_upgrades()
|
||||
.await;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Temporary function to load notary signing key from static file
|
||||
async fn load_notary_signing_key(config: &NotarySigningKeyProperties) -> Result<SigningKey> {
|
||||
debug!("Loading notary server's signing key");
|
||||
|
||||
let notary_signing_key = SigningKey::read_pkcs8_pem_file(&config.private_key_pem_path)
|
||||
.map_err(|err| eyre!("Failed to load notary signing key for notarization: {err}"))?;
|
||||
|
||||
debug!("Successfully loaded notary server's signing key!");
|
||||
Ok(notary_signing_key)
|
||||
}
|
||||
|
||||
/// Read a PEM-formatted file and return its buffer reader
|
||||
pub async fn read_pem_file(file_path: &str) -> Result<BufReader<StdFile>> {
|
||||
let key_file = File::open(file_path).await?.into_std().await;
|
||||
Ok(BufReader::new(key_file))
|
||||
}
|
||||
|
||||
/// Load notary tls private key and cert from static files
|
||||
async fn load_tls_key_and_cert(
|
||||
private_key_pem_path: &str,
|
||||
certificate_pem_path: &str,
|
||||
) -> Result<(PrivateKey, Vec<Certificate>)> {
|
||||
debug!("Loading notary server's tls private key and certificate");
|
||||
|
||||
let mut private_key_file_reader = read_pem_file(private_key_pem_path).await?;
|
||||
let mut private_keys = rustls_pemfile::pkcs8_private_keys(&mut private_key_file_reader)?;
|
||||
ensure!(
|
||||
private_keys.len() == 1,
|
||||
"More than 1 key found in the tls private key pem file"
|
||||
);
|
||||
let private_key = PrivateKey(private_keys.remove(0));
|
||||
|
||||
let mut certificate_file_reader = read_pem_file(certificate_pem_path).await?;
|
||||
let certificates = rustls_pemfile::certs(&mut certificate_file_reader)?
|
||||
.into_iter()
|
||||
.map(Certificate)
|
||||
.collect();
|
||||
|
||||
debug!("Successfully loaded notary server's tls private key and certificate!");
|
||||
Ok((private_key, certificates))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_load_notary_key_and_cert() {
|
||||
let private_key_pem_path = "./fixture/tls/notary.key";
|
||||
let certificate_pem_path = "./fixture/tls/notary.crt";
|
||||
let result: Result<(PrivateKey, Vec<Certificate>)> =
|
||||
load_tls_key_and_cert(private_key_pem_path, certificate_pem_path).await;
|
||||
assert!(result.is_ok(), "Could not load tls private key and cert");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_load_notary_signing_key() {
|
||||
let config = NotarySigningKeyProperties {
|
||||
private_key_pem_path: "./fixture/notary/notary.key".to_string(),
|
||||
public_key_pem_path: "./fixture/notary/notary.pub".to_string(),
|
||||
};
|
||||
let result: Result<SigningKey> = load_notary_signing_key(&config).await;
|
||||
assert!(result.is_ok(), "Could not load notary private key");
|
||||
}
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
use eyre::Result;
|
||||
use opentelemetry::{
|
||||
global,
|
||||
sdk::{export::trace::stdout, propagation::TraceContextPropagator},
|
||||
};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Registry};
|
||||
|
||||
use crate::config::NotaryServerProperties;
|
||||
|
||||
pub fn init_tracing(config: &NotaryServerProperties) -> Result<()> {
|
||||
// Create a new OpenTelemetry pipeline
|
||||
let tracer = stdout::new_pipeline().install_simple();
|
||||
|
||||
// Create a tracing layer with the configured tracer
|
||||
let tracing_layer = tracing_opentelemetry::layer().with_tracer(tracer);
|
||||
|
||||
// Set the log level
|
||||
let env_filter_layer = EnvFilter::new(&config.tracing.default_level);
|
||||
|
||||
// Format the log
|
||||
let format_layer = tracing_subscriber::fmt::layer()
|
||||
// Use a more compact, abbreviated log format
|
||||
.compact()
|
||||
.with_thread_ids(true)
|
||||
.with_thread_names(true);
|
||||
|
||||
// Set up context propagation
|
||||
global::set_text_map_propagator(TraceContextPropagator::default());
|
||||
|
||||
Registry::default()
|
||||
.with(tracing_layer)
|
||||
.with(env_filter_layer)
|
||||
.with(format_layer)
|
||||
.try_init()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
178
src/service.rs
178
src/service.rs
@@ -1,178 +0,0 @@
|
||||
pub mod axum_websocket;
|
||||
pub mod tcp;
|
||||
pub mod websocket;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
extract::{rejection::JsonRejection, FromRequestParts, Query, State},
|
||||
http::{header, request::Parts, StatusCode},
|
||||
response::{IntoResponse, Json, Response},
|
||||
};
|
||||
use axum_macros::debug_handler;
|
||||
use chrono::Utc;
|
||||
use p256::ecdsa::{Signature, SigningKey};
|
||||
use tlsn_verifier::tls::{Verifier, VerifierConfig};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::compat::TokioAsyncReadCompatExt;
|
||||
use tracing::{debug, error, info, trace};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
domain::notary::{
|
||||
NotarizationRequestQuery, NotarizationSessionRequest, NotarizationSessionResponse,
|
||||
NotaryGlobals, SessionData,
|
||||
},
|
||||
error::NotaryServerError,
|
||||
service::{
|
||||
axum_websocket::{header_eq, WebSocketUpgrade},
|
||||
tcp::{tcp_notarize, TcpUpgrade},
|
||||
websocket::websocket_notarize,
|
||||
},
|
||||
};
|
||||
|
||||
/// A wrapper enum to facilitate extracting TCP connection for either WebSocket or TCP clients,
|
||||
/// so that we can use a single endpoint and handler for notarization for both types of clients
|
||||
pub enum ProtocolUpgrade {
|
||||
Tcp(TcpUpgrade),
|
||||
Ws(WebSocketUpgrade),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for ProtocolUpgrade
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = NotaryServerError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
// Extract tcp connection for websocket client
|
||||
if header_eq(&parts.headers, header::UPGRADE, "websocket") {
|
||||
let extractor = WebSocketUpgrade::from_request_parts(parts, state)
|
||||
.await
|
||||
.map_err(|err| NotaryServerError::BadProverRequest(err.to_string()))?;
|
||||
return Ok(Self::Ws(extractor));
|
||||
// Extract tcp connection for tcp client
|
||||
} else if header_eq(&parts.headers, header::UPGRADE, "tcp") {
|
||||
let extractor = TcpUpgrade::from_request_parts(parts, state)
|
||||
.await
|
||||
.map_err(|err| NotaryServerError::BadProverRequest(err.to_string()))?;
|
||||
return Ok(Self::Tcp(extractor));
|
||||
} else {
|
||||
return Err(NotaryServerError::BadProverRequest(
|
||||
"Upgrade header is not set for client".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handler to upgrade protocol from http to either websocket or underlying tcp depending on the type of client
|
||||
/// the session_id parameter is also extracted here to fetch the configuration parameters
|
||||
/// that have been submitted in the previous request to /session made by the same client
|
||||
pub async fn upgrade_protocol(
|
||||
protocol_upgrade: ProtocolUpgrade,
|
||||
State(notary_globals): State<NotaryGlobals>,
|
||||
Query(params): Query<NotarizationRequestQuery>,
|
||||
) -> Response {
|
||||
info!("Received upgrade protocol request");
|
||||
let session_id = params.session_id;
|
||||
// Fetch the configuration data from the store using the session_id
|
||||
// This also removes the configuration data from the store as each session_id can only be used once
|
||||
let max_transcript_size = match notary_globals.store.lock().await.remove(&session_id) {
|
||||
Some(data) => data.max_transcript_size,
|
||||
None => {
|
||||
let err_msg = format!("Session id {} does not exist", session_id);
|
||||
error!(err_msg);
|
||||
return NotaryServerError::BadProverRequest(err_msg).into_response();
|
||||
}
|
||||
};
|
||||
// This completes the HTTP Upgrade request and returns a successful response to the client, meanwhile initiating the websocket or tcp connection
|
||||
match protocol_upgrade {
|
||||
ProtocolUpgrade::Ws(ws) => ws.on_upgrade(move |socket| {
|
||||
websocket_notarize(socket, notary_globals, session_id, max_transcript_size)
|
||||
}),
|
||||
ProtocolUpgrade::Tcp(tcp) => tcp.on_upgrade(move |stream| {
|
||||
tcp_notarize(stream, notary_globals, session_id, max_transcript_size)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handler to initialize and configure notarization for both TCP and WebSocket clients
|
||||
#[debug_handler(state = NotaryGlobals)]
|
||||
pub async fn initialize(
|
||||
State(notary_globals): State<NotaryGlobals>,
|
||||
payload: Result<Json<NotarizationSessionRequest>, JsonRejection>,
|
||||
) -> impl IntoResponse {
|
||||
info!(
|
||||
?payload,
|
||||
"Received request for initializing a notarization session"
|
||||
);
|
||||
|
||||
// Parse the body payload
|
||||
let payload = match payload {
|
||||
Ok(payload) => payload,
|
||||
Err(err) => {
|
||||
error!("Malformed payload submitted for initializing notarization: {err}");
|
||||
return NotaryServerError::BadProverRequest(err.to_string()).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Ensure that the max_transcript_size submitted is not larger than the global max limit configured in notary server
|
||||
if payload.max_transcript_size > Some(notary_globals.notarization_config.max_transcript_size) {
|
||||
error!(
|
||||
"Max transcript size requested {:?} exceeds the maximum threshold {:?}",
|
||||
payload.max_transcript_size, notary_globals.notarization_config.max_transcript_size
|
||||
);
|
||||
return NotaryServerError::BadProverRequest(
|
||||
"Max transcript size requested exceeds the maximum threshold".to_string(),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let prover_session_id = Uuid::new_v4().to_string();
|
||||
|
||||
// Store the configuration data in a temporary store
|
||||
notary_globals.store.lock().await.insert(
|
||||
prover_session_id.clone(),
|
||||
SessionData {
|
||||
max_transcript_size: payload.max_transcript_size,
|
||||
created_at: Utc::now(),
|
||||
},
|
||||
);
|
||||
|
||||
trace!("Latest store state: {:?}", notary_globals.store);
|
||||
|
||||
// Return the session id in the response to the client
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(NotarizationSessionResponse {
|
||||
session_id: prover_session_id,
|
||||
}),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Run the notarization
|
||||
pub async fn notary_service<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
socket: T,
|
||||
signing_key: &SigningKey,
|
||||
session_id: &str,
|
||||
max_transcript_size: Option<usize>,
|
||||
) -> Result<(), NotaryServerError> {
|
||||
debug!(?session_id, "Starting notarization...");
|
||||
|
||||
let mut config_builder = VerifierConfig::builder();
|
||||
|
||||
config_builder = config_builder.id(session_id);
|
||||
|
||||
if let Some(max_transcript_size) = max_transcript_size {
|
||||
config_builder = config_builder.max_transcript_size(max_transcript_size);
|
||||
}
|
||||
|
||||
let config = config_builder.build()?;
|
||||
|
||||
Verifier::new(config)
|
||||
.notarize::<_, Signature>(socket.compat(), signing_key)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,914 +0,0 @@
|
||||
//! The following code is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.6.19/axum/src/extract/ws.rs
|
||||
//! where we swapped out tokio_tungstenite (https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/)
|
||||
//! with async_tungstenite (https://docs.rs/async-tungstenite/latest/async_tungstenite/) so that we can use
|
||||
//! ws_stream_tungstenite (https://docs.rs/ws_stream_tungstenite/latest/ws_stream_tungstenite/index.html)
|
||||
//! to get AsyncRead and AsyncWrite implemented for the WebSocket. Any other modification is commented with the prefix "NOTARY_MODIFICATION:"
|
||||
//!
|
||||
//! The code is under the following license:
|
||||
//!
|
||||
//! Copyright (c) 2019 Axum Contributors
|
||||
//!
|
||||
//! Permission is hereby granted, free of charge, to any
|
||||
//! person obtaining a copy of this software and associated
|
||||
//! documentation files (the "Software"), to deal in the
|
||||
//! Software without restriction, including without
|
||||
//! limitation the rights to use, copy, modify, merge,
|
||||
//! publish, distribute, sublicense, and/or sell copies of
|
||||
//! the Software, and to permit persons to whom the Software
|
||||
//! is furnished to do so, subject to the following
|
||||
//! conditions:
|
||||
//!
|
||||
//! The above copyright notice and this permission notice
|
||||
//! shall be included in all copies or substantial portions
|
||||
//! of the Software.
|
||||
//!
|
||||
//! THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
|
||||
//! ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
|
||||
//! TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
|
||||
//! PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
|
||||
//! SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
//! CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
//! OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
|
||||
//! IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
//! DEALINGS IN THE SOFTWARE.
|
||||
//!
|
||||
//!
|
||||
//! Handle WebSocket connections.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```
|
||||
//! use axum::{
|
||||
//! extract::ws::{WebSocketUpgrade, WebSocket},
|
||||
//! routing::get,
|
||||
//! response::{IntoResponse, Response},
|
||||
//! Router,
|
||||
//! };
|
||||
//!
|
||||
//! let app = Router::new().route("/ws", get(handler));
|
||||
//!
|
||||
//! async fn handler(ws: WebSocketUpgrade) -> Response {
|
||||
//! ws.on_upgrade(handle_socket)
|
||||
//! }
|
||||
//!
|
||||
//! async fn handle_socket(mut socket: WebSocket) {
|
||||
//! while let Some(msg) = socket.recv().await {
|
||||
//! let msg = if let Ok(msg) = msg {
|
||||
//! msg
|
||||
//! } else {
|
||||
//! // client disconnected
|
||||
//! return;
|
||||
//! };
|
||||
//!
|
||||
//! if socket.send(msg).await.is_err() {
|
||||
//! // client disconnected
|
||||
//! return;
|
||||
//! }
|
||||
//! }
|
||||
//! }
|
||||
//! # async {
|
||||
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
//! # };
|
||||
//! ```
|
||||
//!
|
||||
//! # Passing data and/or state to an `on_upgrade` callback
|
||||
//!
|
||||
//! ```
|
||||
//! use axum::{
|
||||
//! extract::{ws::{WebSocketUpgrade, WebSocket}, State},
|
||||
//! response::Response,
|
||||
//! routing::get,
|
||||
//! Router,
|
||||
//! };
|
||||
//!
|
||||
//! #[derive(Clone)]
|
||||
//! struct AppState {
|
||||
//! // ...
|
||||
//! }
|
||||
//!
|
||||
//! async fn handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
|
||||
//! ws.on_upgrade(|socket| handle_socket(socket, state))
|
||||
//! }
|
||||
//!
|
||||
//! async fn handle_socket(socket: WebSocket, state: AppState) {
|
||||
//! // ...
|
||||
//! }
|
||||
//!
|
||||
//! let app = Router::new()
|
||||
//! .route("/ws", get(handler))
|
||||
//! .with_state(AppState { /* ... */ });
|
||||
//! # async {
|
||||
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
//! # };
|
||||
//! ```
|
||||
//!
|
||||
//! # Read and write concurrently
|
||||
//!
|
||||
//! If you need to read and write concurrently from a [`WebSocket`] you can use
|
||||
//! [`StreamExt::split`]:
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use axum::{Error, extract::ws::{WebSocket, Message}};
|
||||
//! use futures_util::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}};
|
||||
//!
|
||||
//! async fn handle_socket(mut socket: WebSocket) {
|
||||
//! let (mut sender, mut receiver) = socket.split();
|
||||
//!
|
||||
//! tokio::spawn(write(sender));
|
||||
//! tokio::spawn(read(receiver));
|
||||
//! }
|
||||
//!
|
||||
//! async fn read(receiver: SplitStream<WebSocket>) {
|
||||
//! // ...
|
||||
//! }
|
||||
//!
|
||||
//! async fn write(sender: SplitSink<WebSocket, Message>) {
|
||||
//! // ...
|
||||
//! }
|
||||
//! ```
|
||||
//!
|
||||
//! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
|
||||
|
||||
#![allow(unused)]
|
||||
|
||||
use self::rejection::*;
|
||||
use async_trait::async_trait;
|
||||
use async_tungstenite::{
|
||||
tokio::TokioAdapter,
|
||||
tungstenite::{
|
||||
self as ts,
|
||||
protocol::{self, WebSocketConfig},
|
||||
},
|
||||
WebSocketStream,
|
||||
};
|
||||
use axum::{
|
||||
body::{self, Bytes},
|
||||
extract::FromRequestParts,
|
||||
response::Response,
|
||||
Error,
|
||||
};
|
||||
|
||||
use futures_util::{
|
||||
sink::{Sink, SinkExt},
|
||||
stream::{Stream, StreamExt},
|
||||
};
|
||||
use http::{
|
||||
header::{self, HeaderMap, HeaderName, HeaderValue},
|
||||
request::Parts,
|
||||
Method, StatusCode,
|
||||
};
|
||||
use hyper::upgrade::{OnUpgrade, Upgraded};
|
||||
use sha1::{Digest, Sha1};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tracing::error;
|
||||
|
||||
/// Extractor for establishing WebSocket connections.
|
||||
///
|
||||
/// Note: This extractor requires the request method to be `GET` so it should
|
||||
/// always be used with [`get`](crate::routing::get). Requests with other methods will be
|
||||
/// rejected.
|
||||
///
|
||||
/// See the [module docs](self) for an example.
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
|
||||
pub struct WebSocketUpgrade<F = DefaultOnFailedUpdgrade> {
|
||||
config: WebSocketConfig,
|
||||
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
|
||||
protocol: Option<HeaderValue>,
|
||||
sec_websocket_key: HeaderValue,
|
||||
on_upgrade: OnUpgrade,
|
||||
on_failed_upgrade: F,
|
||||
sec_websocket_protocol: Option<HeaderValue>,
|
||||
}
|
||||
|
||||
impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("WebSocketUpgrade")
|
||||
.field("config", &self.config)
|
||||
.field("protocol", &self.protocol)
|
||||
.field("sec_websocket_key", &self.sec_websocket_key)
|
||||
.field("sec_websocket_protocol", &self.sec_websocket_protocol)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> WebSocketUpgrade<F> {
|
||||
/// Set the size of the internal message send queue.
|
||||
pub fn max_send_queue(mut self, max: usize) -> Self {
|
||||
self.config.max_send_queue = Some(max);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the maximum message size (defaults to 64 megabytes)
|
||||
pub fn max_message_size(mut self, max: usize) -> Self {
|
||||
self.config.max_message_size = Some(max);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the maximum frame size (defaults to 16 megabytes)
|
||||
pub fn max_frame_size(mut self, max: usize) -> Self {
|
||||
self.config.max_frame_size = Some(max);
|
||||
self
|
||||
}
|
||||
|
||||
/// Allow server to accept unmasked frames (defaults to false)
|
||||
pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
|
||||
self.config.accept_unmasked_frames = accept;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the known protocols.
|
||||
///
|
||||
/// If the protocol name specified by `Sec-WebSocket-Protocol` header
|
||||
/// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
|
||||
/// return the protocol name.
|
||||
///
|
||||
/// The protocols should be listed in decreasing order of preference: if the client offers
|
||||
/// multiple protocols that the server could support, the server will pick the first one in
|
||||
/// this list.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use axum::{
|
||||
/// extract::ws::{WebSocketUpgrade, WebSocket},
|
||||
/// routing::get,
|
||||
/// response::{IntoResponse, Response},
|
||||
/// Router,
|
||||
/// };
|
||||
///
|
||||
/// let app = Router::new().route("/ws", get(handler));
|
||||
///
|
||||
/// async fn handler(ws: WebSocketUpgrade) -> Response {
|
||||
/// ws.protocols(["graphql-ws", "graphql-transport-ws"])
|
||||
/// .on_upgrade(|socket| async {
|
||||
/// // ...
|
||||
/// })
|
||||
/// }
|
||||
/// # async {
|
||||
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
pub fn protocols<I>(mut self, protocols: I) -> Self
|
||||
where
|
||||
I: IntoIterator,
|
||||
I::Item: Into<Cow<'static, str>>,
|
||||
{
|
||||
if let Some(req_protocols) = self
|
||||
.sec_websocket_protocol
|
||||
.as_ref()
|
||||
.and_then(|p| p.to_str().ok())
|
||||
{
|
||||
self.protocol = protocols
|
||||
.into_iter()
|
||||
// FIXME: This will often allocate a new `String` and so is less efficient than it
|
||||
// could be. But that can't be fixed without breaking changes to the public API.
|
||||
.map(Into::into)
|
||||
.find(|protocol| {
|
||||
req_protocols
|
||||
.split(',')
|
||||
.any(|req_protocol| req_protocol.trim() == protocol)
|
||||
})
|
||||
.map(|protocol| match protocol {
|
||||
Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
|
||||
Cow::Borrowed(s) => HeaderValue::from_static(s),
|
||||
});
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Provide a callback to call if upgrading the connection fails.
|
||||
///
|
||||
/// The connection upgrade is performed in a background task. If that fails this callback
|
||||
/// will be called.
|
||||
///
|
||||
/// By default any errors will be silently ignored.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use axum::{
|
||||
/// extract::{WebSocketUpgrade},
|
||||
/// response::Response,
|
||||
/// };
|
||||
///
|
||||
/// async fn handler(ws: WebSocketUpgrade) -> Response {
|
||||
/// ws.on_failed_upgrade(|error| {
|
||||
/// report_error(error);
|
||||
/// })
|
||||
/// .on_upgrade(|socket| async { /* ... */ })
|
||||
/// }
|
||||
/// #
|
||||
/// # fn report_error(_: axum::Error) {}
|
||||
/// ```
|
||||
pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
|
||||
where
|
||||
C: OnFailedUpdgrade,
|
||||
{
|
||||
WebSocketUpgrade {
|
||||
config: self.config,
|
||||
protocol: self.protocol,
|
||||
sec_websocket_key: self.sec_websocket_key,
|
||||
on_upgrade: self.on_upgrade,
|
||||
on_failed_upgrade: callback,
|
||||
sec_websocket_protocol: self.sec_websocket_protocol,
|
||||
}
|
||||
}
|
||||
|
||||
/// Finalize upgrading the connection and call the provided callback with
|
||||
/// the stream.
|
||||
#[must_use = "to setup the WebSocket connection, this response must be returned"]
|
||||
pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
|
||||
where
|
||||
C: FnOnce(WebSocket) -> Fut + Send + 'static,
|
||||
Fut: Future<Output = ()> + Send + 'static,
|
||||
F: OnFailedUpdgrade,
|
||||
{
|
||||
let on_upgrade = self.on_upgrade;
|
||||
let config = self.config;
|
||||
let on_failed_upgrade = self.on_failed_upgrade;
|
||||
|
||||
let protocol = self.protocol.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let upgraded = match on_upgrade.await {
|
||||
Ok(upgraded) => upgraded,
|
||||
Err(err) => {
|
||||
error!("Something wrong with on_upgrade: {:?}", err);
|
||||
on_failed_upgrade.call(Error::new(err));
|
||||
return;
|
||||
}
|
||||
};
|
||||
let socket = WebSocketStream::from_raw_socket(
|
||||
// NOTARY_MODIFICATION: Need to use TokioAdapter to wrap Upgraded which doesn't implement futures crate's AsyncRead and AsyncWrite
|
||||
TokioAdapter::new(upgraded),
|
||||
protocol::Role::Server,
|
||||
Some(config),
|
||||
)
|
||||
.await;
|
||||
let socket = WebSocket {
|
||||
inner: socket,
|
||||
protocol,
|
||||
};
|
||||
callback(socket).await;
|
||||
});
|
||||
|
||||
#[allow(clippy::declare_interior_mutable_const)]
|
||||
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
|
||||
#[allow(clippy::declare_interior_mutable_const)]
|
||||
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
|
||||
|
||||
let mut builder = Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(header::CONNECTION, UPGRADE)
|
||||
.header(header::UPGRADE, WEBSOCKET)
|
||||
.header(
|
||||
header::SEC_WEBSOCKET_ACCEPT,
|
||||
sign(self.sec_websocket_key.as_bytes()),
|
||||
);
|
||||
|
||||
if let Some(protocol) = self.protocol {
|
||||
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
|
||||
}
|
||||
|
||||
builder.body(body::boxed(body::Empty::new())).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// What to do when a connection upgrade fails.
|
||||
///
|
||||
/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
|
||||
pub trait OnFailedUpdgrade: Send + 'static {
|
||||
/// Call the callback.
|
||||
fn call(self, error: Error);
|
||||
}
|
||||
|
||||
impl<F> OnFailedUpdgrade for F
|
||||
where
|
||||
F: FnOnce(Error) + Send + 'static,
|
||||
{
|
||||
fn call(self, error: Error) {
|
||||
self(error)
|
||||
}
|
||||
}
|
||||
|
||||
/// The default `OnFailedUpdgrade` used by `WebSocketUpgrade`.
|
||||
///
|
||||
/// It simply ignores the error.
|
||||
#[non_exhaustive]
|
||||
#[derive(Debug)]
|
||||
pub struct DefaultOnFailedUpdgrade;
|
||||
|
||||
impl OnFailedUpdgrade for DefaultOnFailedUpdgrade {
|
||||
#[inline]
|
||||
fn call(self, _error: Error) {}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpdgrade>
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = WebSocketUpgradeRejection;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
if parts.method != Method::GET {
|
||||
return Err(MethodNotGet.into());
|
||||
}
|
||||
|
||||
if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
|
||||
return Err(InvalidConnectionHeader.into());
|
||||
}
|
||||
|
||||
if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
|
||||
return Err(InvalidUpgradeHeader.into());
|
||||
}
|
||||
|
||||
if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
|
||||
return Err(InvalidWebSocketVersionHeader.into());
|
||||
}
|
||||
|
||||
let sec_websocket_key = parts
|
||||
.headers
|
||||
.get(header::SEC_WEBSOCKET_KEY)
|
||||
.ok_or(WebSocketKeyHeaderMissing)?
|
||||
.clone();
|
||||
|
||||
let on_upgrade = parts
|
||||
.extensions
|
||||
.remove::<OnUpgrade>()
|
||||
.ok_or(ConnectionNotUpgradable)?;
|
||||
|
||||
let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
|
||||
|
||||
Ok(Self {
|
||||
config: Default::default(),
|
||||
protocol: None,
|
||||
sec_websocket_key,
|
||||
on_upgrade,
|
||||
sec_websocket_protocol,
|
||||
on_failed_upgrade: DefaultOnFailedUpdgrade,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
|
||||
if let Some(header) = headers.get(&key) {
|
||||
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
|
||||
let header = if let Some(header) = headers.get(&key) {
|
||||
header
|
||||
} else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
|
||||
header.to_ascii_lowercase().contains(value)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// A stream of WebSocket messages.
|
||||
///
|
||||
/// See [the module level documentation](self) for more details.
|
||||
#[derive(Debug)]
|
||||
pub struct WebSocket {
|
||||
inner: WebSocketStream<TokioAdapter<Upgraded>>,
|
||||
protocol: Option<HeaderValue>,
|
||||
}
|
||||
|
||||
impl WebSocket {
|
||||
/// Consume `self` and get the inner [`async_tungstenite::WebSocketStream`].
|
||||
pub fn into_inner(self) -> WebSocketStream<TokioAdapter<Upgraded>> {
|
||||
self.inner
|
||||
}
|
||||
|
||||
/// Receive another message.
|
||||
///
|
||||
/// Returns `None` if the stream has closed.
|
||||
pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
|
||||
self.next().await
|
||||
}
|
||||
|
||||
/// Send a message.
|
||||
pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
|
||||
self.inner
|
||||
.send(msg.into_tungstenite())
|
||||
.await
|
||||
.map_err(Error::new)
|
||||
}
|
||||
|
||||
/// Gracefully close this WebSocket.
|
||||
pub async fn close(mut self) -> Result<(), Error> {
|
||||
self.inner.close(None).await.map_err(Error::new)
|
||||
}
|
||||
|
||||
/// Return the selected WebSocket subprotocol, if one has been chosen.
|
||||
pub fn protocol(&self) -> Option<&HeaderValue> {
|
||||
self.protocol.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for WebSocket {
|
||||
type Item = Result<Message, Error>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
|
||||
Some(Ok(msg)) => {
|
||||
if let Some(msg) = Message::from_tungstenite(msg) {
|
||||
return Poll::Ready(Some(Ok(msg)));
|
||||
}
|
||||
}
|
||||
Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))),
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<Message> for WebSocket {
|
||||
type Error = Error;
|
||||
|
||||
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new)
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
|
||||
Pin::new(&mut self.inner)
|
||||
.start_send(item.into_tungstenite())
|
||||
.map_err(Error::new)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new)
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new)
|
||||
}
|
||||
}
|
||||
|
||||
/// Status code used to indicate why an endpoint is closing the WebSocket connection.
|
||||
pub type CloseCode = u16;
|
||||
|
||||
/// A struct representing the close command.
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub struct CloseFrame<'t> {
|
||||
/// The reason as a code.
|
||||
pub code: CloseCode,
|
||||
/// The reason as text string.
|
||||
pub reason: Cow<'t, str>,
|
||||
}
|
||||
|
||||
/// A WebSocket message.
|
||||
//
|
||||
// This code comes from https://github.com/snapview/tungstenite-rs/blob/master/src/protocol/message.rs and is under following license:
|
||||
// Copyright (c) 2017 Alexey Galakhov
|
||||
// Copyright (c) 2016 Jason Housley
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
// of this software and associated documentation files (the "Software"), to deal
|
||||
// in the Software without restriction, including without limitation the rights
|
||||
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
// copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
// THE SOFTWARE.
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
pub enum Message {
|
||||
/// A text WebSocket message
|
||||
Text(String),
|
||||
/// A binary WebSocket message
|
||||
Binary(Vec<u8>),
|
||||
/// A ping message with the specified payload
|
||||
///
|
||||
/// The payload here must have a length less than 125 bytes.
|
||||
///
|
||||
/// Ping messages will be automatically responded to by the server, so you do not have to worry
|
||||
/// about dealing with them yourself.
|
||||
Ping(Vec<u8>),
|
||||
/// A pong message with the specified payload
|
||||
///
|
||||
/// The payload here must have a length less than 125 bytes.
|
||||
///
|
||||
/// Pong messages will be automatically sent to the client if a ping message is received, so
|
||||
/// you do not have to worry about constructing them yourself unless you want to implement a
|
||||
/// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
|
||||
Pong(Vec<u8>),
|
||||
/// A close message with the optional close frame.
|
||||
Close(Option<CloseFrame<'static>>),
|
||||
}
|
||||
|
||||
impl Message {
|
||||
fn into_tungstenite(self) -> ts::Message {
|
||||
match self {
|
||||
Self::Text(text) => ts::Message::Text(text),
|
||||
Self::Binary(binary) => ts::Message::Binary(binary),
|
||||
Self::Ping(ping) => ts::Message::Ping(ping),
|
||||
Self::Pong(pong) => ts::Message::Pong(pong),
|
||||
Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
|
||||
code: ts::protocol::frame::coding::CloseCode::from(close.code),
|
||||
reason: close.reason,
|
||||
})),
|
||||
Self::Close(None) => ts::Message::Close(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn from_tungstenite(message: ts::Message) -> Option<Self> {
|
||||
match message {
|
||||
ts::Message::Text(text) => Some(Self::Text(text)),
|
||||
ts::Message::Binary(binary) => Some(Self::Binary(binary)),
|
||||
ts::Message::Ping(ping) => Some(Self::Ping(ping)),
|
||||
ts::Message::Pong(pong) => Some(Self::Pong(pong)),
|
||||
ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
|
||||
code: close.code.into(),
|
||||
reason: close.reason,
|
||||
}))),
|
||||
ts::Message::Close(None) => Some(Self::Close(None)),
|
||||
// we can ignore `Frame` frames as recommended by the tungstenite maintainers
|
||||
// https://github.com/snapview/tungstenite-rs/issues/268
|
||||
ts::Message::Frame(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Consume the WebSocket and return it as binary data.
|
||||
pub fn into_data(self) -> Vec<u8> {
|
||||
match self {
|
||||
Self::Text(string) => string.into_bytes(),
|
||||
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
|
||||
Self::Close(None) => Vec::new(),
|
||||
Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to consume the WebSocket message and convert it to a String.
|
||||
pub fn into_text(self) -> Result<String, Error> {
|
||||
match self {
|
||||
Self::Text(string) => Ok(string),
|
||||
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data)
|
||||
.map_err(|err| err.utf8_error())
|
||||
.map_err(Error::new)?),
|
||||
Self::Close(None) => Ok(String::new()),
|
||||
Self::Close(Some(frame)) => Ok(frame.reason.into_owned()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to get a &str from the WebSocket message,
|
||||
/// this will try to convert binary data to utf8.
|
||||
pub fn to_text(&self) -> Result<&str, Error> {
|
||||
match *self {
|
||||
Self::Text(ref string) => Ok(string),
|
||||
Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
|
||||
Ok(std::str::from_utf8(data).map_err(Error::new)?)
|
||||
}
|
||||
Self::Close(None) => Ok(""),
|
||||
Self::Close(Some(ref frame)) => Ok(&frame.reason),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for Message {
|
||||
fn from(string: String) -> Self {
|
||||
Message::Text(string)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'s> From<&'s str> for Message {
|
||||
fn from(string: &'s str) -> Self {
|
||||
Message::Text(string.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'b> From<&'b [u8]> for Message {
|
||||
fn from(data: &'b [u8]) -> Self {
|
||||
Message::Binary(data.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<u8>> for Message {
|
||||
fn from(data: Vec<u8>) -> Self {
|
||||
Message::Binary(data)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Message> for Vec<u8> {
|
||||
fn from(msg: Message) -> Self {
|
||||
msg.into_data()
|
||||
}
|
||||
}
|
||||
|
||||
fn sign(key: &[u8]) -> HeaderValue {
|
||||
use base64::engine::Engine as _;
|
||||
|
||||
let mut sha1 = Sha1::default();
|
||||
sha1.update(key);
|
||||
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
|
||||
let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
|
||||
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
|
||||
}
|
||||
|
||||
pub mod rejection {
|
||||
//! WebSocket specific rejections.
|
||||
|
||||
use axum_core::{
|
||||
__composite_rejection as composite_rejection, __define_rejection as define_rejection,
|
||||
};
|
||||
|
||||
define_rejection! {
|
||||
#[status = METHOD_NOT_ALLOWED]
|
||||
#[body = "Request method must be `GET`"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub struct MethodNotGet;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "Connection header did not include 'upgrade'"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub struct InvalidConnectionHeader;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "`Upgrade` header did not include 'websocket'"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub struct InvalidUpgradeHeader;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "`Sec-WebSocket-Version` header did not include '13'"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub struct InvalidWebSocketVersionHeader;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "`Sec-WebSocket-Key` header missing"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub struct WebSocketKeyHeaderMissing;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = UPGRADE_REQUIRED]
|
||||
#[body = "WebSocket request couldn't be upgraded since no upgrade state was present"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
///
|
||||
/// This rejection is returned if the connection cannot be upgraded for example if the
|
||||
/// request is HTTP/1.0.
|
||||
///
|
||||
/// See [MDN] for more details about connection upgrades.
|
||||
///
|
||||
/// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Upgrade
|
||||
pub struct ConnectionNotUpgradable;
|
||||
}
|
||||
|
||||
composite_rejection! {
|
||||
/// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
///
|
||||
/// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
|
||||
/// extractor can fail.
|
||||
pub enum WebSocketUpgradeRejection {
|
||||
MethodNotGet,
|
||||
InvalidConnectionHeader,
|
||||
InvalidUpgradeHeader,
|
||||
InvalidWebSocketVersionHeader,
|
||||
WebSocketKeyHeaderMissing,
|
||||
ConnectionNotUpgradable,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub mod close_code {
|
||||
//! Constants for [`CloseCode`]s.
|
||||
//!
|
||||
//! [`CloseCode`]: super::CloseCode
|
||||
|
||||
/// Indicates a normal closure, meaning that the purpose for which the connection was
|
||||
/// established has been fulfilled.
|
||||
pub const NORMAL: u16 = 1000;
|
||||
|
||||
/// Indicates that an endpoint is "going away", such as a server going down or a browser having
|
||||
/// navigated away from a page.
|
||||
pub const AWAY: u16 = 1001;
|
||||
|
||||
/// Indicates that an endpoint is terminating the connection due to a protocol error.
|
||||
pub const PROTOCOL: u16 = 1002;
|
||||
|
||||
/// Indicates that an endpoint is terminating the connection because it has received a type of
|
||||
/// data it cannot accept (e.g., an endpoint that understands only text data MAY send this if
|
||||
/// it receives a binary message).
|
||||
pub const UNSUPPORTED: u16 = 1003;
|
||||
|
||||
/// Indicates that no status code was included in a closing frame.
|
||||
pub const STATUS: u16 = 1005;
|
||||
|
||||
/// Indicates an abnormal closure.
|
||||
pub const ABNORMAL: u16 = 1006;
|
||||
|
||||
/// Indicates that an endpoint is terminating the connection because it has received data
|
||||
/// within a message that was not consistent with the type of the message (e.g., non-UTF-8
|
||||
/// RFC3629 data within a text message).
|
||||
pub const INVALID: u16 = 1007;
|
||||
|
||||
/// Indicates that an endpoint is terminating the connection because it has received a message
|
||||
/// that violates its policy. This is a generic status code that can be returned when there is
|
||||
/// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a need to
|
||||
/// hide specific details about the policy.
|
||||
pub const POLICY: u16 = 1008;
|
||||
|
||||
/// Indicates that an endpoint is terminating the connection because it has received a message
|
||||
/// that is too big for it to process.
|
||||
pub const SIZE: u16 = 1009;
|
||||
|
||||
/// Indicates that an endpoint (client) is terminating the connection because it has expected
|
||||
/// the server to negotiate one or more extension, but the server didn't return them in the
|
||||
/// response message of the WebSocket handshake. The list of extensions that are needed should
|
||||
/// be given as the reason for closing. Note that this status code is not used by the server,
|
||||
/// because it can fail the WebSocket handshake instead.
|
||||
pub const EXTENSION: u16 = 1010;
|
||||
|
||||
/// Indicates that a server is terminating the connection because it encountered an unexpected
|
||||
/// condition that prevented it from fulfilling the request.
|
||||
pub const ERROR: u16 = 1011;
|
||||
|
||||
/// Indicates that the server is restarting.
|
||||
pub const RESTART: u16 = 1012;
|
||||
|
||||
/// Indicates that the server is overloaded and the client should either connect to a different
|
||||
/// IP (when multiple targets exist), or reconnect to the same IP when a user has performed an
|
||||
/// action.
|
||||
pub const AGAIN: u16 = 1013;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::{body::Body, routing::get, Router};
|
||||
use http::{Request, Version};
|
||||
use tower::ServiceExt;
|
||||
|
||||
#[tokio::test]
|
||||
async fn rejects_http_1_0_requests() {
|
||||
let svc = get(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
|
||||
let rejection = ws.unwrap_err();
|
||||
assert!(matches!(
|
||||
rejection,
|
||||
WebSocketUpgradeRejection::ConnectionNotUpgradable(_)
|
||||
));
|
||||
std::future::ready(())
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.version(Version::HTTP_10)
|
||||
.method(Method::GET)
|
||||
.header("upgrade", "websocket")
|
||||
.header("connection", "Upgrade")
|
||||
.header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==")
|
||||
.header("sec-websocket-version", "13")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let res = svc.oneshot(req).await.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn default_on_failed_upgrade() {
|
||||
async fn handler(ws: WebSocketUpgrade) -> Response {
|
||||
ws.on_upgrade(|_| async {})
|
||||
}
|
||||
let _: Router = Router::new().route("/", get(handler));
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn on_failed_upgrade() {
|
||||
async fn handler(ws: WebSocketUpgrade) -> Response {
|
||||
ws.on_failed_upgrade(|_error: Error| println!("oops!"))
|
||||
.on_upgrade(|_| async {})
|
||||
}
|
||||
let _: Router = Router::new().route("/", get(handler));
|
||||
}
|
||||
}
|
||||
@@ -1,101 +0,0 @@
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body,
|
||||
extract::FromRequestParts,
|
||||
http::{header, request::Parts, HeaderValue, StatusCode},
|
||||
response::Response,
|
||||
};
|
||||
use hyper::upgrade::{OnUpgrade, Upgraded};
|
||||
use std::future::Future;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use crate::{domain::notary::NotaryGlobals, service::notary_service, NotaryServerError};
|
||||
|
||||
/// Custom extractor used to extract underlying TCP connection for TCP client — using the same upgrade primitives used by
|
||||
/// the WebSocket implementation where the underlying TCP connection (wrapped in an Upgraded object) only gets polled as an OnUpgrade future
|
||||
/// after the ongoing HTTP request is finished (ref: https://github.com/tokio-rs/axum/blob/a6a849bb5b96a2f641fa077fe76f70ad4d20341c/axum/src/extract/ws.rs#L122)
|
||||
///
|
||||
/// More info on the upgrade primitives: https://docs.rs/hyper/latest/hyper/upgrade/index.html
|
||||
pub struct TcpUpgrade {
|
||||
pub on_upgrade: OnUpgrade,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for TcpUpgrade
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = NotaryServerError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
let on_upgrade =
|
||||
parts
|
||||
.extensions
|
||||
.remove::<OnUpgrade>()
|
||||
.ok_or(NotaryServerError::BadProverRequest(
|
||||
"Upgrade header is not set for TCP client".to_string(),
|
||||
))?;
|
||||
|
||||
Ok(Self { on_upgrade })
|
||||
}
|
||||
}
|
||||
|
||||
impl TcpUpgrade {
|
||||
/// Utility function to complete the http upgrade protocol by
|
||||
/// (1) Return 101 switching protocol response to client to indicate the switching to TCP
|
||||
/// (2) Spawn a new thread to await on the OnUpgrade object to claim the underlying TCP connection
|
||||
pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
|
||||
where
|
||||
C: FnOnce(Upgraded) -> Fut + Send + 'static,
|
||||
Fut: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let on_upgrade = self.on_upgrade;
|
||||
tokio::spawn(async move {
|
||||
let upgraded = match on_upgrade.await {
|
||||
Ok(upgraded) => upgraded,
|
||||
Err(err) => {
|
||||
error!("Something wrong with upgrading HTTP: {:?}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
callback(upgraded).await;
|
||||
});
|
||||
|
||||
#[allow(clippy::declare_interior_mutable_const)]
|
||||
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
|
||||
#[allow(clippy::declare_interior_mutable_const)]
|
||||
const TCP: HeaderValue = HeaderValue::from_static("tcp");
|
||||
|
||||
let builder = Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(header::CONNECTION, UPGRADE)
|
||||
.header(header::UPGRADE, TCP);
|
||||
|
||||
builder.body(body::boxed(body::Empty::new())).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform notarization using the extracted tcp connection
|
||||
pub async fn tcp_notarize(
|
||||
stream: Upgraded,
|
||||
notary_globals: NotaryGlobals,
|
||||
session_id: String,
|
||||
max_transcript_size: Option<usize>,
|
||||
) {
|
||||
debug!(?session_id, "Upgraded to tcp connection");
|
||||
match notary_service(
|
||||
stream,
|
||||
¬ary_globals.notary_signing_key,
|
||||
&session_id,
|
||||
max_transcript_size,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {
|
||||
info!(?session_id, "Successful notarization using tcp!");
|
||||
}
|
||||
Err(err) => {
|
||||
error!(?session_id, "Failed notarization using tcp: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
use tracing::{debug, error, info};
|
||||
use ws_stream_tungstenite::WsStream;
|
||||
|
||||
use crate::{
|
||||
domain::notary::NotaryGlobals,
|
||||
service::{axum_websocket::WebSocket, notary_service},
|
||||
};
|
||||
|
||||
/// Perform notarization using the established websocket connection
|
||||
pub async fn websocket_notarize(
|
||||
socket: WebSocket,
|
||||
notary_globals: NotaryGlobals,
|
||||
session_id: String,
|
||||
max_transcript_size: Option<usize>,
|
||||
) {
|
||||
debug!(?session_id, "Upgraded to websocket connection");
|
||||
// Wrap the websocket in WsStream so that we have AsyncRead and AsyncWrite implemented
|
||||
let stream = WsStream::new(socket.into_inner());
|
||||
match notary_service(
|
||||
stream,
|
||||
¬ary_globals.notary_signing_key,
|
||||
&session_id,
|
||||
max_transcript_size,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {
|
||||
info!(?session_id, "Successful notarization using websocket!");
|
||||
}
|
||||
Err(err) => {
|
||||
error!(?session_id, "Failed notarization using websocket: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
52
src/util.rs
52
src/util.rs
@@ -1,52 +0,0 @@
|
||||
use eyre::Result;
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
/// Parse a yaml configuration file into a struct
|
||||
pub fn parse_config_file<T: DeserializeOwned>(location: &str) -> Result<T> {
|
||||
let file = std::fs::File::open(location)?;
|
||||
let config: T = serde_yaml::from_reader(file)?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Parse a csv file into a vec of structs
|
||||
pub fn parse_csv_file<T: DeserializeOwned>(location: &str) -> Result<Vec<T>> {
|
||||
let file = std::fs::File::open(location)?;
|
||||
let mut reader = csv::Reader::from_reader(file);
|
||||
let mut table: Vec<T> = Vec::new();
|
||||
for result in reader.deserialize() {
|
||||
let record: T = result?;
|
||||
table.push(record);
|
||||
}
|
||||
Ok(table)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
|
||||
use crate::{
|
||||
config::NotaryServerProperties, domain::auth::AuthorizationWhitelistRecord,
|
||||
util::parse_csv_file,
|
||||
};
|
||||
|
||||
use super::{parse_config_file, Result};
|
||||
|
||||
#[test]
|
||||
fn test_parse_config_file() {
|
||||
let location = "./config/config.yaml";
|
||||
let config: Result<NotaryServerProperties> = parse_config_file(location);
|
||||
assert!(
|
||||
config.is_ok(),
|
||||
"Could not open file or read the file's values."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_csv_file() {
|
||||
let location = "./fixture/auth/whitelist.csv";
|
||||
let table: Result<Vec<AuthorizationWhitelistRecord>> = parse_csv_file(location);
|
||||
assert!(
|
||||
table.is_ok(),
|
||||
"Could not open csv or read the csv's values."
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user