Update hyper and tlsn version for interactive verifier.

This commit is contained in:
yuroitaki
2024-09-03 10:59:44 +08:00
parent 6dbc409663
commit f801e164b9
5 changed files with 104 additions and 82 deletions

View File

@@ -4,8 +4,13 @@ version = "0.1.0"
edition = "2021"
[dependencies]
tracing = "0.1.40"
tracing-subscriber = { version ="0.3.18", features = ["env-filter"] }
async-tungstenite = { version = "0.25", features = ["tokio-runtime"] }
futures = "0.3"
http = "1.1"
http-body-util = "0.1"
hyper = {version = "1.1", features = ["client", "http1"]}
hyper-util = {version = "0.1", features = ["full"]}
regex = "1.10.3"
tokio = {version = "1", features = [
"rt",
"rt-multi-thread",
@@ -15,17 +20,10 @@ tokio = {version = "1", features = [
"fs",
]}
tokio-util = { version = "0.7", features = ["compat"] }
http = "0.2.9"
hyper-util = {version = "0.1", features = ["full"]}
http-body-util = "0.1"
hyper = {version = "1.1", features = ["client", "http1"]}
tracing = "0.1.40"
tracing-subscriber = { version ="0.3.18", features = ["env-filter"] }
uuid = { version = "1.4.1", features = ["v4", "fast-rng"] }
regex = "1.10.3"
futures = "0.3"
async-tungstenite = { version = "0.22.2", features = ["tokio-runtime"] }
ws_stream_tungstenite = { version = "0.10.0", features = ["tokio_io"] }
ws_stream_tungstenite = { version = "0.13", features = ["tokio_io"] }
tlsn-core = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.4", package = "tlsn-core" }
tlsn-prover = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.4", package = "tlsn-prover", features = [
"tracing",
] }
tlsn-core = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.6", package = "tlsn-core" }
tlsn-prover = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.6", package = "tlsn-prover" }

View File

@@ -1,5 +1,4 @@
use async_tungstenite::{tokio::connect_async_with_config, tungstenite::protocol::WebSocketConfig};
use futures::AsyncWriteExt;
use http_body_util::Empty;
use hyper::{body::Bytes, Request, StatusCode, Uri};
use hyper_util::rt::TokioIo;
@@ -107,7 +106,7 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
.await
.unwrap();
let connection_task = tokio::spawn(connection.without_shutdown());
tokio::spawn(connection);
// MPC-TLS: Send Request and wait for Response.
info!("Send Request and wait for Response");

View File

@@ -4,6 +4,19 @@ version = "0.1.0"
edition = "2021"
[dependencies]
async-trait = "0.1.67"
async-tungstenite = { version = "0.25", features = ["tokio-native-tls"] }
axum = { version = "0.7", features = ["ws"] }
axum-core = "0.4"
base64 = "0.21.0"
eyre = "0.6.12"
futures-util = "0.3.28"
http = { version = "1.1" }
http-body-util = { version = "0.1" }
hyper = { version = "1.1", features = ["client", "http1", "server"] }
hyper-util = { version = "0.1", features = ["full"] }
serde = { version = "1.0.147", features = ["derive"] }
sha1 = "0.10"
tokio = {version = "1", features = [
"rt",
"rt-multi-thread",
@@ -13,23 +26,11 @@ tokio = {version = "1", features = [
"fs",
]}
tokio-util = { version = "0.7", features = ["compat"] }
tracing = "0.1.40"
eyre = "0.6.12"
tracing-subscriber = { version ="0.3.18", features = ["env-filter"] }
hyper = { version = "0.14", features = ["client", "http1", "server", "tcp"] }
axum = { version = "0.6.18", features = ["ws"] }
serde = { version = "1.0.147", features = ["derive"] }
futures-util = "0.3.28"
http = "0.2.9"
sha1 = "0.10"
async-trait = "0.1.67"
async-tungstenite = { version = "0.22.2", features = ["tokio-native-tls"] }
axum-core = "0.3.4"
base64 = "0.21.0"
tower = { version = "0.4.12", features = ["make"] }
ws_stream_tungstenite = { version = "0.10.0", features = ["tokio_io"] }
tower-service = { version = "0.3" }
tracing = "0.1.40"
tracing-subscriber = { version ="0.3.18", features = ["env-filter"] }
ws_stream_tungstenite = { version = "0.13", features = ["tokio_io"] }
tlsn-core = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.4", package = "tlsn-core" }
tlsn-verifier = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.4", package = "tlsn-verifier", features = [
"tracing",
] }
tlsn-core = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.6", package = "tlsn-core" }
tlsn-verifier = { git = "https://github.com/tlsnotary/tlsn.git", tag = "v0.1.0-alpha.6", package = "tlsn-verifier" }

View File

@@ -1,4 +1,4 @@
//! The following code is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.6.19/axum/src/extract/ws.rs
//! The following code is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.7.3/axum/src/extract/ws.rs
//! where we swapped out tokio_tungstenite (https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/)
//! with async_tungstenite (https://docs.rs/async-tungstenite/latest/async_tungstenite/) so that we can use
//! ws_stream_tungstenite (https://docs.rs/ws_stream_tungstenite/latest/ws_stream_tungstenite/index.html)
@@ -66,9 +66,7 @@
//! }
//! }
//! }
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! # let _: Router = app;
//! ```
//!
//! # Passing data and/or state to an `on_upgrade` callback
@@ -97,9 +95,7 @@
//! let app = Router::new()
//! .route("/ws", get(handler))
//! .with_state(AppState { /* ... */ });
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! # let _: Router = app;
//! ```
//!
//! # Read and write concurrently
@@ -128,7 +124,6 @@
//! ```
//!
//! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
#![allow(unused)]
use self::rejection::*;
@@ -141,13 +136,8 @@ use async_tungstenite::{
},
WebSocketStream,
};
use axum::{
body::{self, Bytes},
extract::FromRequestParts,
response::Response,
Error,
};
use axum::{body::Bytes, extract::FromRequestParts, response::Response, Error};
use axum_core::body::Body;
use futures_util::{
sink::{Sink, SinkExt},
stream::{Stream, StreamExt},
@@ -157,7 +147,7 @@ use http::{
request::Parts,
Method, StatusCode,
};
use hyper::upgrade::{OnUpgrade, Upgraded};
use hyper_util::rt::TokioIo;
use sha1::{Digest, Sha1};
use std::{
borrow::Cow,
@@ -175,12 +165,12 @@ use tracing::error;
///
/// See the [module docs](self) for an example.
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpdgrade> {
pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
config: WebSocketConfig,
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
protocol: Option<HeaderValue>,
sec_websocket_key: HeaderValue,
on_upgrade: OnUpgrade,
on_upgrade: hyper::upgrade::OnUpgrade,
on_failed_upgrade: F,
sec_websocket_protocol: Option<HeaderValue>,
}
@@ -197,9 +187,33 @@ impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
}
impl<F> WebSocketUpgrade<F> {
/// Set the size of the internal message send queue.
pub fn max_send_queue(mut self, max: usize) -> Self {
self.config.max_send_queue = Some(max);
/// The target minimum size of the write buffer to reach before writing the data
/// to the underlying stream.
///
/// The default value is 128 KiB.
///
/// If set to `0` each message will be eagerly written to the underlying stream.
/// It is often more optimal to allow them to buffer a little, hence the default value.
///
/// Note: [`flush`](SinkExt::flush) will always fully write the buffer regardless.
pub fn write_buffer_size(mut self, size: usize) -> Self {
self.config.write_buffer_size = size;
self
}
/// The max size of the write buffer in bytes. Setting this can provide backpressure
/// in the case the write buffer is filling up due to write errors.
///
/// The default value is unlimited.
///
/// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size)
/// when writes to the underlying stream are failing. So the **write buffer can not
/// fill up if you are not observing write errors even if not flushing**.
///
/// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size)
/// and probably a little more depending on error handling strategy.
pub fn max_write_buffer_size(mut self, max: usize) -> Self {
self.config.max_write_buffer_size = max;
self
}
@@ -249,9 +263,7 @@ impl<F> WebSocketUpgrade<F> {
/// // ...
/// })
/// }
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// # let _: Router = app;
/// ```
pub fn protocols<I>(mut self, protocols: I) -> Self
where
@@ -308,7 +320,7 @@ impl<F> WebSocketUpgrade<F> {
/// ```
pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
where
C: OnFailedUpdgrade,
C: OnFailedUpgrade,
{
WebSocketUpgrade {
config: self.config,
@@ -322,12 +334,12 @@ impl<F> WebSocketUpgrade<F> {
/// Finalize upgrading the connection and call the provided callback with
/// the stream.
#[must_use = "to setup the WebSocket connection, this response must be returned"]
#[must_use = "to set up the WebSocket connection, this response must be returned"]
pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
where
C: FnOnce(WebSocket) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
F: OnFailedUpdgrade,
F: OnFailedUpgrade,
{
let on_upgrade = self.on_upgrade;
let config = self.config;
@@ -344,6 +356,8 @@ impl<F> WebSocketUpgrade<F> {
return;
}
};
let upgraded = TokioIo::new(upgraded);
let socket = WebSocketStream::from_raw_socket(
// NOTARY_MODIFICATION: Need to use TokioAdapter to wrap Upgraded which doesn't implement futures crate's AsyncRead and AsyncWrite
TokioAdapter::new(upgraded),
@@ -376,19 +390,19 @@ impl<F> WebSocketUpgrade<F> {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
builder.body(body::boxed(body::Empty::new())).unwrap()
builder.body(Body::empty()).unwrap()
}
}
/// What to do when a connection upgrade fails.
///
/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
pub trait OnFailedUpdgrade: Send + 'static {
pub trait OnFailedUpgrade: Send + 'static {
/// Call the callback.
fn call(self, error: Error);
}
impl<F> OnFailedUpdgrade for F
impl<F> OnFailedUpgrade for F
where
F: FnOnce(Error) + Send + 'static,
{
@@ -397,20 +411,20 @@ where
}
}
/// The default `OnFailedUpdgrade` used by `WebSocketUpgrade`.
/// The default `OnFailedUpgrade` used by `WebSocketUpgrade`.
///
/// It simply ignores the error.
#[non_exhaustive]
#[derive(Debug)]
pub struct DefaultOnFailedUpdgrade;
pub struct DefaultOnFailedUpgrade;
impl OnFailedUpdgrade for DefaultOnFailedUpdgrade {
impl OnFailedUpgrade for DefaultOnFailedUpgrade {
#[inline]
fn call(self, _error: Error) {}
}
#[async_trait]
impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpdgrade>
impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpgrade>
where
S: Send + Sync,
{
@@ -441,7 +455,7 @@ where
let on_upgrade = parts
.extensions
.remove::<OnUpgrade>()
.remove::<hyper::upgrade::OnUpgrade>()
.ok_or(ConnectionNotUpgradable)?;
let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
@@ -452,11 +466,12 @@ where
sec_websocket_key,
on_upgrade,
sec_websocket_protocol,
on_failed_upgrade: DefaultOnFailedUpdgrade,
on_failed_upgrade: DefaultOnFailedUpgrade,
})
}
}
/// NOTARY_MODIFICATION: Made this function public to be used in service.rs
pub fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
if let Some(header) = headers.get(&key) {
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
@@ -484,13 +499,13 @@ fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) ->
/// See [the module level documentation](self) for more details.
#[derive(Debug)]
pub struct WebSocket {
inner: WebSocketStream<TokioAdapter<Upgraded>>,
inner: WebSocketStream<TokioAdapter<TokioIo<hyper::upgrade::Upgraded>>>,
protocol: Option<HeaderValue>,
}
impl WebSocket {
/// Consume `self` and get the inner [`async_tungstenite::WebSocketStream`].
pub fn into_inner(self) -> WebSocketStream<TokioAdapter<Upgraded>> {
/// NOTARY_MODIFICATION: Consume `self` and get the inner [`async_tungstenite::WebSocketStream`].
pub fn into_inner(self) -> WebSocketStream<TokioAdapter<TokioIo<hyper::upgrade::Upgraded>>> {
self.inner
}

View File

@@ -1,8 +1,13 @@
use axum::{extract::State, response::IntoResponse, routing::get, Router};
use axum::{
extract::{Request, State},
response::IntoResponse,
routing::get,
Router,
};
use axum_websocket::{WebSocket, WebSocketUpgrade};
use eyre::eyre;
use http::Request;
use hyper::server::conn::Http;
use hyper::{body::Incoming, server::conn::http1};
use hyper_util::rt::TokioIo;
use std::{
net::{IpAddr, SocketAddr},
sync::Arc,
@@ -14,7 +19,7 @@ use tokio::{
net::TcpListener,
};
use tokio_util::compat::TokioAsyncReadCompatExt;
use tower::MakeService;
use tower_service::Service;
use tracing::{debug, error, info};
use ws_stream_tungstenite::WsStream;
@@ -45,14 +50,13 @@ pub async fn run_server(
info!("Listening for TCP traffic at {}", verifier_address);
let protocol = Arc::new(Http::new());
let protocol = Arc::new(http1::Builder::new());
let router = Router::new()
.route("/verify", get(ws_handler))
.with_state(VerifierGlobals {
server_domain: server_domain.to_string(),
verification_session_id: verification_session_id.to_string(),
});
let mut app = router.into_make_service();
loop {
let stream = match listener.accept().await {
@@ -64,14 +68,19 @@ pub async fn run_server(
};
debug!("Received a prover's TCP connection");
let tower_service = router.clone();
let protocol = protocol.clone();
let service = MakeService::<_, Request<hyper::Body>>::make_service(&mut app, &stream);
tokio::spawn(async move {
info!("Accepted prover's TCP connection",);
// Reference: https://github.com/tokio-rs/axum/blob/5201798d4e4d4759c208ef83e30ce85820c07baa/examples/low-level-rustls/src/main.rs#L67-L80
let io = TokioIo::new(stream);
let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
tower_service.clone().call(request)
});
// Serve different requests using the same hyper protocol and axum router
let _ = protocol
// Can unwrap because it's infallible
.serve_connection(stream, service.await.unwrap())
.serve_connection(io, hyper_service)
// use with_upgrades to upgrade connection to websocket for websocket clients
// and to extract tcp connection for tcp clients
.with_upgrades()