chore: update axum to v0.8 (#681)

chore: update `axum` to v0.8

Co-authored-by: yuroitaki <25913766+yuroitaki@users.noreply.github.com>
This commit is contained in:
Hendrik Eeckhaut
2025-01-08 09:24:01 +01:00
committed by GitHub
parent c03418a642
commit 65299d7def
7 changed files with 342 additions and 123 deletions

View File

@@ -88,8 +88,8 @@ aes = { version = "0.8" }
aes-gcm = { version = "0.9" }
anyhow = { version = "1.0" }
async-trait = { version = "0.1" }
async-tungstenite = { version = "0.25" }
axum = { version = "0.7" }
async-tungstenite = { version = "0.28.2" }
axum = { version = "0.8" }
bcs = { version = "0.1" }
bincode = { version = "1.3" }
blake3 = { version = "1.5" }
@@ -144,5 +144,5 @@ uuid = { version = "1.4" }
web-time = { version = "0.2" }
webpki = { version = "0.22" }
webpki-roots = { version = "0.26" }
ws_stream_tungstenite = { version = "0.13" }
ws_stream_tungstenite = { version = "0.14" }
zeroize = { version = "1.8" }

View File

@@ -19,17 +19,17 @@ tlsn-core = { workspace = true }
tlsn-common = { workspace = true }
tlsn-verifier = { workspace = true }
async-trait = { workspace = true }
async-tungstenite = { workspace = true, features = ["tokio-native-tls"] }
axum = { workspace = true, features = ["ws"] }
axum-core = { version = "0.4" }
axum-macros = { version = "0.4" }
axum-core = { version = "0.5" }
axum-macros = { version = "0.5" }
base64 = { version = "0.21" }
config = { version = "0.14", features = ["yaml"] }
csv = { version = "1.3" }
eyre = { version = "0.6" }
futures-util = { workspace = true }
http = { workspace = true }
http-body-util = { workspace = true }
hyper = { workspace = true, features = ["client", "http1", "server"] }
hyper-util = { workspace = true, features = ["full"] }
k256 = { workspace = true }

View File

@@ -1,4 +1,3 @@
use async_trait::async_trait;
use axum::http::{header, request::Parts};
use axum_core::extract::{FromRef, FromRequestParts};
use std::collections::HashMap;
@@ -12,7 +11,6 @@ use crate::{
/// Auth middleware to prevent DOS
pub struct AuthorizationMiddleware;
#[async_trait]
impl<S> FromRequestParts<S> for AuthorizationMiddleware
where
NotaryGlobals: FromRef<S>,

View File

@@ -2,7 +2,6 @@ pub mod axum_websocket;
pub mod tcp;
pub mod websocket;
use async_trait::async_trait;
use axum::{
extract::{rejection::JsonRejection, FromRequestParts, Query, State},
http::{header, request::Parts, StatusCode},
@@ -43,7 +42,6 @@ pub enum ProtocolUpgrade {
Ws(WebSocketUpgrade),
}
#[async_trait]
impl<S> FromRequestParts<S> for ProtocolUpgrade
where
S: Send + Sync,
@@ -56,17 +54,17 @@ where
let extractor = WebSocketUpgrade::from_request_parts(parts, state)
.await
.map_err(|err| NotaryServerError::BadProverRequest(err.to_string()))?;
return Ok(Self::Ws(extractor));
Ok(Self::Ws(extractor))
// Extract tcp connection for tcp client
} else if header_eq(&parts.headers, header::UPGRADE, "tcp") {
let extractor = TcpUpgrade::from_request_parts(parts, state)
.await
.map_err(|err| NotaryServerError::BadProverRequest(err.to_string()))?;
return Ok(Self::Tcp(extractor));
Ok(Self::Tcp(extractor))
} else {
return Err(NotaryServerError::BadProverRequest(
Err(NotaryServerError::BadProverRequest(
"Upgrade header is not set for client".to_string(),
));
))
}
}
}

View File

@@ -1,4 +1,4 @@
//! The following code is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.7.3/axum/src/extract/ws.rs
//! The following code is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.8.0/axum/src/extract/ws.rs
//! where we swapped out tokio_tungstenite (https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/)
//! with async_tungstenite (https://docs.rs/async-tungstenite/latest/async_tungstenite/) so that we can use
//! ws_stream_tungstenite (https://docs.rs/ws_stream_tungstenite/latest/ws_stream_tungstenite/index.html)
@@ -41,12 +41,12 @@
//! ```
//! use axum::{
//! extract::ws::{WebSocketUpgrade, WebSocket},
//! routing::get,
//! routing::any,
//! response::{IntoResponse, Response},
//! Router,
//! };
//!
//! let app = Router::new().route("/ws", get(handler));
//! let app = Router::new().route("/ws", any(handler));
//!
//! async fn handler(ws: WebSocketUpgrade) -> Response {
//! ws.on_upgrade(handle_socket)
@@ -76,7 +76,7 @@
//! use axum::{
//! extract::{ws::{WebSocketUpgrade, WebSocket}, State},
//! response::Response,
//! routing::get,
//! routing::any,
//! Router,
//! };
//!
@@ -94,7 +94,7 @@
//! }
//!
//! let app = Router::new()
//! .route("/ws", get(handler))
//! .route("/ws", any(handler))
//! .with_state(AppState { /* ... */ });
//! # let _: Router = app;
//! ```
@@ -125,10 +125,11 @@
//! ```
//!
//! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
#![allow(unused)]
#![allow(unused)] // NOTARY_MODIFICATION
use self::rejection::*;
use async_trait::async_trait;
/// NOTARY_MODIFICATION: async_tungstenite instead of tokio_tungstenite
use async_tungstenite::{
tokio::TokioAdapter,
tungstenite::{
@@ -137,6 +138,7 @@ use async_tungstenite::{
},
WebSocketStream,
};
/// NOTARY_MODIFICATION: axum
use axum::{body::Bytes, extract::FromRequestParts, response::Response, Error};
use axum_core::body::Body;
use futures_util::{
@@ -146,7 +148,7 @@ use futures_util::{
use http::{
header::{self, HeaderMap, HeaderName, HeaderValue},
request::Parts,
Method, StatusCode,
Method, StatusCode, Version,
};
use hyper_util::rt::TokioIo;
use sha1::{Digest, Sha1};
@@ -160,18 +162,21 @@ use tracing::error;
/// Extractor for establishing WebSocket connections.
///
/// Note: This extractor requires the request method to be `GET` so it should
/// always be used with [`get`](crate::routing::get). Requests with other
/// methods will be rejected.
/// For HTTP/1.1 requests, this extractor requires the request method to be
/// `GET`; in later versions, `CONNECT` is used instead.
/// To support both, it should be used with [`any`](crate::routing::any).
///
/// See the [module docs](self) for an example.
///
/// [`MethodFilter`]: crate::routing::MethodFilter
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
config: WebSocketConfig,
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the
/// response.
protocol: Option<HeaderValue>,
sec_websocket_key: HeaderValue,
/// `None` if HTTP/2+ WebSockets are used.
sec_websocket_key: Option<HeaderValue>,
on_upgrade: hyper::upgrade::OnUpgrade,
on_failed_upgrade: F,
sec_websocket_protocol: Option<HeaderValue>,
@@ -257,12 +262,12 @@ impl<F> WebSocketUpgrade<F> {
/// ```
/// use axum::{
/// extract::ws::{WebSocketUpgrade, WebSocket},
/// routing::get,
/// routing::any,
/// response::{IntoResponse, Response},
/// Router,
/// };
///
/// let app = Router::new().route("/ws", get(handler));
/// let app = Router::new().route("/ws", any(handler));
///
/// async fn handler(ws: WebSocketUpgrade) -> Response {
/// ws.protocols(["graphql-ws", "graphql-transport-ws"])
@@ -358,6 +363,7 @@ impl<F> WebSocketUpgrade<F> {
let upgraded = match on_upgrade.await {
Ok(upgraded) => upgraded,
Err(err) => {
// NOTARY_MODIFICATION: log error
error!("Something wrong with on_upgrade: {:?}", err);
on_failed_upgrade.call(Error::new(err));
return;
@@ -380,25 +386,34 @@ impl<F> WebSocketUpgrade<F> {
callback(socket).await;
});
#[allow(clippy::declare_interior_mutable_const)]
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
#[allow(clippy::declare_interior_mutable_const)]
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
if let Some(sec_websocket_key) = &self.sec_websocket_key {
// If `sec_websocket_key` was `Some`, we are using HTTP/1.1.
let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, UPGRADE)
.header(header::UPGRADE, WEBSOCKET)
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(self.sec_websocket_key.as_bytes()),
);
#[allow(clippy::declare_interior_mutable_const)]
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
#[allow(clippy::declare_interior_mutable_const)]
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
if let Some(protocol) = self.protocol {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, UPGRADE)
.header(header::UPGRADE, WEBSOCKET)
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(sec_websocket_key.as_bytes()),
);
if let Some(protocol) = self.protocol {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
builder.body(Body::empty()).unwrap()
} else {
// Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just
// respond with a 2XX with an empty body:
// <https://datatracker.ietf.org/doc/html/rfc9113#name-the-connect-method>.
Response::new(Body::empty())
}
builder.body(Body::empty()).unwrap()
}
}
@@ -431,7 +446,6 @@ impl OnFailedUpgrade for DefaultOnFailedUpgrade {
fn call(self, _error: Error) {}
}
#[async_trait]
impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpgrade>
where
S: Send + Sync,
@@ -439,28 +453,51 @@ where
type Rejection = WebSocketUpgradeRejection;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if parts.method != Method::GET {
return Err(MethodNotGet.into());
}
let sec_websocket_key = if parts.version <= Version::HTTP_11 {
if parts.method != Method::GET {
return Err(MethodNotGet.into());
}
if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
return Err(InvalidConnectionHeader.into());
}
if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
return Err(InvalidConnectionHeader.into());
}
if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
return Err(InvalidUpgradeHeader.into());
}
if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
return Err(InvalidUpgradeHeader.into());
}
Some(
parts
.headers
.get(header::SEC_WEBSOCKET_KEY)
.ok_or(WebSocketKeyHeaderMissing)?
.clone(),
)
} else {
if parts.method != Method::CONNECT {
return Err(MethodNotConnect.into());
}
// NOTARY_MODIFICATION: ignore http2 feature
// if this feature flag is disabled, we wont be receiving an HTTP/2 request to
// begin with.
// #[cfg(feature = "http2")]
// if parts
// .extensions
// .get::<hyper::ext::Protocol>()
// .map_or(true, |p| p.as_str() != "websocket")
// {
// return Err(InvalidProtocolPseudoheader.into());
// }
None
};
if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
return Err(InvalidWebSocketVersionHeader.into());
}
let sec_websocket_key = parts
.headers
.get(header::SEC_WEBSOCKET_KEY)
.ok_or(WebSocketKeyHeaderMissing)?
.clone();
let on_upgrade = parts
.extensions
.remove::<hyper::upgrade::OnUpgrade>()
@@ -507,6 +544,7 @@ fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) ->
/// See [the module level documentation](self) for more details.
#[derive(Debug)]
pub struct WebSocket {
// NOTARY_MODIFICATION: TokioAdapter
inner: WebSocketStream<TokioAdapter<TokioIo<hyper::upgrade::Upgraded>>>,
protocol: Option<HeaderValue>,
}
@@ -533,11 +571,6 @@ impl WebSocket {
.map_err(Error::new)
}
/// Gracefully close this WebSocket.
pub async fn close(mut self) -> Result<(), Error> {
self.inner.close(None).await.map_err(Error::new)
}
/// Return the selected WebSocket subprotocol, if one has been chosen.
pub fn protocol(&self) -> Option<&HeaderValue> {
self.protocol.as_ref()
@@ -584,17 +617,137 @@ impl Sink<Message> for WebSocket {
}
}
/// UTF-8 wrapper for [Bytes].
///
/// An [Utf8Bytes] is always guaranteed to contain valid UTF-8.
/// The following NOTARY_MODIFICATION(s) are required because
/// `async_tungstenite` (v0.28.2) is using an older version of `tungstenite`
/// than `tokio_tungstenite` (v0.26.1). This older version of `tungstenite`
/// (v0.26.0) doesn't have `Utf8Bytes`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Utf8Bytes(axum::extract::ws::Utf8Bytes); // NOTARY_MODIFICATION
impl Utf8Bytes {
/// Creates from a static str.
#[inline]
pub const fn from_static(str: &'static str) -> Self {
Self(axum::extract::ws::Utf8Bytes::from_static(str)) // NOTARY_MODIFICATION
}
/// Returns as a string slice.
#[inline]
pub fn as_str(&self) -> &str {
self.0.as_str()
}
fn into_tungstenite(self) -> axum::extract::ws::Utf8Bytes {
// NOTARY_MODIFICATION
self.0
}
}
impl std::ops::Deref for Utf8Bytes {
type Target = str;
/// ```
/// /// Example fn that takes a str slice
/// fn a(s: &str) {}
///
/// let data = axum::extract::ws::Utf8Bytes::from_static("foo123");
///
/// // auto-deref as arg
/// a(&data);
///
/// // deref to str methods
/// assert_eq!(data.len(), 6);
/// ```
#[inline]
fn deref(&self) -> &Self::Target {
self.as_str()
}
}
impl std::fmt::Display for Utf8Bytes {
#[inline]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
impl TryFrom<Bytes> for Utf8Bytes {
type Error = std::str::Utf8Error;
#[inline]
fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
Ok(Self(bytes.try_into()?))
}
}
impl TryFrom<Vec<u8>> for Utf8Bytes {
type Error = std::str::Utf8Error;
#[inline]
fn try_from(v: Vec<u8>) -> Result<Self, Self::Error> {
Ok(Self(v.try_into()?))
}
}
impl From<String> for Utf8Bytes {
#[inline]
fn from(s: String) -> Self {
Self(s.into())
}
}
impl From<&str> for Utf8Bytes {
#[inline]
fn from(s: &str) -> Self {
Self(s.into())
}
}
impl From<&String> for Utf8Bytes {
#[inline]
fn from(s: &String) -> Self {
Self(s.into())
}
}
impl From<Utf8Bytes> for Bytes {
#[inline]
fn from(Utf8Bytes(bytes): Utf8Bytes) -> Self {
bytes.into()
}
}
impl<T> PartialEq<T> for Utf8Bytes
where
for<'a> &'a str: PartialEq<T>,
{
/// ```
/// let payload = axum::extract::ws::Utf8Bytes::from_static("foo123");
/// assert_eq!(payload, "foo123");
/// assert_eq!(payload, "foo123".to_string());
/// assert_eq!(payload, &"foo123".to_string());
/// assert_eq!(payload, std::borrow::Cow::from("foo123"));
/// ```
#[inline]
fn eq(&self, other: &T) -> bool {
self.as_str() == *other
}
}
/// Status code used to indicate why an endpoint is closing the WebSocket
/// connection.
pub type CloseCode = u16;
/// A struct representing the close command.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame<'t> {
pub struct CloseFrame {
/// The reason as a code.
pub code: CloseCode,
/// The reason as text string.
pub reason: Cow<'t, str>,
pub reason: Utf8Bytes,
}
/// A WebSocket message.
@@ -623,16 +776,16 @@ pub struct CloseFrame<'t> {
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message {
/// A text WebSocket message
Text(String),
Text(Utf8Bytes),
/// A binary WebSocket message
Binary(Vec<u8>),
Binary(Bytes),
/// A ping message with the specified payload
///
/// The payload here must have a length less than 125 bytes.
///
/// Ping messages will be automatically responded to by the server, so you
/// do not have to worry about dealing with them yourself.
Ping(Vec<u8>),
Ping(Bytes),
/// A pong message with the specified payload
///
/// The payload here must have a length less than 125 bytes.
@@ -640,21 +793,44 @@ pub enum Message {
/// Pong messages will be automatically sent to the client if a ping message
/// is received, so you do not have to worry about constructing them
/// yourself unless you want to implement a [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
Pong(Vec<u8>),
Pong(Bytes),
/// A close message with the optional close frame.
Close(Option<CloseFrame<'static>>),
///
/// You may "uncleanly" close a WebSocket connection at any time
/// by simply dropping the [`WebSocket`].
/// However, you may also use the graceful closing protocol, in which
/// 1. peer A sends a close frame, and does not send any further messages;
/// 2. peer B responds with a close frame, and does not send any further
/// messages;
/// 3. peer A processes the remaining messages sent by peer B, before
/// finally
/// 4. both peers close the connection.
///
/// After sending a close frame,
/// you may still read messages,
/// but attempts to send another message will error.
/// After receiving a close frame,
/// axum will automatically respond with a close frame if necessary
/// (you do not have to deal with this yourself).
/// Since no further messages will be received,
/// you may either do nothing
/// or explicitly drop the connection.
Close(Option<CloseFrame>),
}
/// The following NOTARY_MODIFICATION(s) are required because
/// `async_tungstenite` (v0.28.2) is using an older version of `tungstenite`
/// than `tokio_tungstenite` (v0.26.1).
impl Message {
fn into_tungstenite(self) -> ts::Message {
match self {
Self::Text(text) => ts::Message::Text(text),
Self::Binary(binary) => ts::Message::Binary(binary),
Self::Ping(ping) => ts::Message::Ping(ping),
Self::Pong(pong) => ts::Message::Pong(pong),
Self::Text(text) => ts::Message::Text(text.into_tungstenite().to_string()), /* NOTARY_MODIFICATION */
Self::Binary(binary) => ts::Message::Binary(binary.to_vec()), /* NOTARY_MODIFICATION */
Self::Ping(ping) => ts::Message::Ping(ping.to_vec()), /* NOTARY_MODIFICATION */
Self::Pong(pong) => ts::Message::Pong(pong.to_vec()), /* NOTARY_MODIFICATION */
Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
code: ts::protocol::frame::coding::CloseCode::from(close.code),
reason: close.reason,
reason: Cow::Owned(close.reason.into_tungstenite().to_string()), /* NOTARY_MODIFICATION */
})),
Self::Close(None) => ts::Message::Close(None),
}
@@ -662,13 +838,13 @@ impl Message {
fn from_tungstenite(message: ts::Message) -> Option<Self> {
match message {
ts::Message::Text(text) => Some(Self::Text(text)),
ts::Message::Binary(binary) => Some(Self::Binary(binary)),
ts::Message::Ping(ping) => Some(Self::Ping(ping)),
ts::Message::Pong(pong) => Some(Self::Pong(pong)),
ts::Message::Text(text) => Some(Self::Text(Utf8Bytes(text.into()))), /* NOTARY_MODIFICATION */
ts::Message::Binary(binary) => Some(Self::Binary(binary.into())), /* NOTARY_MODIFICATION */
ts::Message::Ping(ping) => Some(Self::Ping(ping.into())), /* NOTARY_MODIFICATION */
ts::Message::Pong(pong) => Some(Self::Pong(pong.into())), /* NOTARY_MODIFICATION */
ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
code: close.code.into(),
reason: close.reason,
reason: Utf8Bytes(close.reason.to_string().into()), /* NOTARY_MODIFICATION */
}))),
ts::Message::Close(None) => Some(Self::Close(None)),
// we can ignore `Frame` frames as recommended by the tungstenite maintainers
@@ -678,24 +854,24 @@ impl Message {
}
/// Consume the WebSocket and return it as binary data.
pub fn into_data(self) -> Vec<u8> {
pub fn into_data(self) -> Bytes {
match self {
Self::Text(string) => string.into_bytes(),
Self::Text(string) => Bytes::from(string),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
Self::Close(None) => Vec::new(),
Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
Self::Close(None) => Bytes::new(),
Self::Close(Some(frame)) => Bytes::from(frame.reason),
}
}
/// Attempt to consume the WebSocket message and convert it to a String.
pub fn into_text(self) -> Result<String, Error> {
/// Attempt to consume the WebSocket message and convert it to a Utf8Bytes.
pub fn into_text(self) -> Result<Utf8Bytes, Error> {
match self {
Self::Text(string) => Ok(string),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data)
.map_err(|err| err.utf8_error())
.map_err(Error::new)?),
Self::Close(None) => Ok(String::new()),
Self::Close(Some(frame)) => Ok(frame.reason.into_owned()),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => {
Ok(Utf8Bytes::try_from(data).map_err(Error::new)?)
}
Self::Close(None) => Ok(Utf8Bytes::default()),
Self::Close(Some(frame)) => Ok(frame.reason),
}
}
@@ -703,7 +879,7 @@ impl Message {
/// this will try to convert binary data to utf8.
pub fn to_text(&self) -> Result<&str, Error> {
match *self {
Self::Text(ref string) => Ok(string),
Self::Text(ref string) => Ok(string.as_str()),
Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
Ok(std::str::from_utf8(data).map_err(Error::new)?)
}
@@ -711,11 +887,27 @@ impl Message {
Self::Close(Some(ref frame)) => Ok(&frame.reason),
}
}
/// Create a new text WebSocket message from a stringable.
pub fn text<S>(string: S) -> Message
where
S: Into<Utf8Bytes>,
{
Message::Text(string.into())
}
/// Create a new binary WebSocket message by converting to `Bytes`.
pub fn binary<B>(bin: B) -> Message
where
B: Into<Bytes>,
{
Message::Binary(bin.into())
}
}
impl From<String> for Message {
fn from(string: String) -> Self {
Message::Text(string)
Message::Text(string.into())
}
}
@@ -727,19 +919,19 @@ impl<'s> From<&'s str> for Message {
impl<'b> From<&'b [u8]> for Message {
fn from(data: &'b [u8]) -> Self {
Message::Binary(data.into())
Message::Binary(Bytes::copy_from_slice(data))
}
}
impl From<Vec<u8>> for Message {
fn from(data: Vec<u8>) -> Self {
Message::Binary(data)
Message::Binary(data.into())
}
}
impl From<Message> for Vec<u8> {
fn from(msg: Message) -> Self {
msg.into_data()
msg.into_data().to_vec()
}
}
@@ -767,6 +959,13 @@ pub mod rejection {
pub struct MethodNotGet;
}
define_rejection! {
#[status = METHOD_NOT_ALLOWED]
#[body = "Request method must be `CONNECT`"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct MethodNotConnect;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Connection header did not include 'upgrade'"]
@@ -781,6 +980,13 @@ pub mod rejection {
pub struct InvalidUpgradeHeader;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`:protocol` pseudo-header did not include 'websocket'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidProtocolPseudoheader;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Sec-WebSocket-Version` header did not include '13'"]
@@ -816,8 +1022,10 @@ pub mod rejection {
/// extractor can fail.
pub enum WebSocketUpgradeRejection {
MethodNotGet,
MethodNotConnect,
InvalidConnectionHeader,
InvalidUpgradeHeader,
InvalidProtocolPseudoheader,
InvalidWebSocketVersionHeader,
WebSocketKeyHeaderMissing,
ConnectionNotUpgradable,
@@ -843,9 +1051,10 @@ pub mod close_code {
pub const PROTOCOL: u16 = 1002;
/// Indicates that an endpoint is terminating the connection because it has
/// received a type of data it cannot accept (e.g., an endpoint that
/// understands only text data MAY send this if it receives a binary
/// message).
/// received a type of data that it cannot accept.
///
/// For example, an endpoint MAY send this if it understands only text data,
/// but receives a binary message.
pub const UNSUPPORTED: u16 = 1003;
/// Indicates that no status code was included in a closing frame.
@@ -856,14 +1065,18 @@ pub mod close_code {
/// Indicates that an endpoint is terminating the connection because it has
/// received data within a message that was not consistent with the type
/// of the message (e.g., non-UTF-8 RFC3629 data within a text message).
/// of the message.
///
/// For example, an endpoint received non-UTF-8 RFC3629 data within a text
/// message.
pub const INVALID: u16 = 1007;
/// Indicates that an endpoint is terminating the connection because it has
/// received a message that violates its policy. This is a generic
/// status code that can be returned when there is no other more
/// suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a
/// need to hide specific details about the policy.
/// received a message that violates its policy.
///
/// This is a generic status code that can be returned when there is
/// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if
/// there is a need to hide specific details about the policy.
pub const POLICY: u16 = 1008;
/// Indicates that an endpoint is terminating the connection because it has
@@ -871,12 +1084,15 @@ pub mod close_code {
pub const SIZE: u16 = 1009;
/// Indicates that an endpoint (client) is terminating the connection
/// because it has expected the server to negotiate one or more
/// extension, but the server didn't return them in the response message
/// of the WebSocket handshake. The list of extensions that are needed
/// should be given as the reason for closing. Note that this status
/// code is not used by the server, because it can fail the WebSocket
/// handshake instead.
/// because the server did not respond to extension negotiation
/// correctly.
///
/// Specifically, the client has expected the server to negotiate one or
/// more extension(s), but the server didn't return them in the response
/// message of the WebSocket handshake. The list of extensions that are
/// needed should be given as the reason for closing. Note that this
/// status code is not used by the server, because it can fail the
/// WebSocket handshake instead.
pub const EXTENSION: u16 = 1010;
/// Indicates that a server is terminating the connection because it
@@ -896,14 +1112,22 @@ pub mod close_code {
#[cfg(test)]
mod tests {
use super::*;
use axum::{body::Body, routing::get, Router};
use async_tungstenite::tungstenite; // NOTARY_MODIFICATION
use axum::routing::any;
use axum::{body::Body, routing::get, Router}; // NOTARY_MODIFICATION
use http::{Request, Version};
// NOTARY_MODIFICATION: use tower_util instead of tower to make clippy happy
use http_body_util::BodyExt as _;
use hyper_util::rt::TokioExecutor;
use std::future::ready;
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
};
use tower_util::ServiceExt;
#[tokio::test]
#[tokio::test] // NOTARY_MODIFICATION
async fn rejects_http_1_0_requests() {
let svc = get(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
let svc = any(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
let rejection = ws.unwrap_err();
assert!(matches!(
rejection,
@@ -932,7 +1156,7 @@ mod tests {
async fn handler(ws: WebSocketUpgrade) -> Response {
ws.on_upgrade(|_| async {})
}
let _: Router = Router::new().route("/", get(handler));
let _: Router = Router::new().route("/", any(handler));
}
#[allow(dead_code)]
@@ -941,6 +1165,8 @@ mod tests {
ws.on_failed_upgrade(|_error: Error| println!("oops!"))
.on_upgrade(|_| async {})
}
let _: Router = Router::new().route("/", get(handler));
let _: Router = Router::new().route("/", any(handler));
}
// NOTARY_MODIFICATION: removed integration test
}

View File

@@ -1,4 +1,3 @@
use async_trait::async_trait;
use axum::{
extract::FromRequestParts,
http::{header, request::Parts, HeaderValue, StatusCode},
@@ -22,7 +21,6 @@ pub struct TcpUpgrade {
pub on_upgrade: OnUpgrade,
}
#[async_trait]
impl<S> FromRequestParts<S> for TcpUpgrade
where
S: Send + Sync,

View File

@@ -26,7 +26,7 @@ use serde_json::Value;
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tower_service::Service;
use axum::{async_trait, extract::FromRequest};
use axum::extract::FromRequest;
use hyper::header;
use tlsn_server_fixture_certs::*;
@@ -143,10 +143,9 @@ async fn html(
struct AuthenticatedUser;
#[async_trait]
impl<B> FromRequest<B> for AuthenticatedUser
where
B: Send,
B: Send + Sync,
{
type Rejection = (StatusCode, &'static str);