diff --git a/Cargo.toml b/Cargo.toml index c68fc387d..1292746fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -88,8 +88,8 @@ aes = { version = "0.8" } aes-gcm = { version = "0.9" } anyhow = { version = "1.0" } async-trait = { version = "0.1" } -async-tungstenite = { version = "0.25" } -axum = { version = "0.7" } +async-tungstenite = { version = "0.28.2" } +axum = { version = "0.8" } bcs = { version = "0.1" } bincode = { version = "1.3" } blake3 = { version = "1.5" } @@ -144,5 +144,5 @@ uuid = { version = "1.4" } web-time = { version = "0.2" } webpki = { version = "0.22" } webpki-roots = { version = "0.26" } -ws_stream_tungstenite = { version = "0.13" } +ws_stream_tungstenite = { version = "0.14" } zeroize = { version = "1.8" } diff --git a/crates/notary/server/Cargo.toml b/crates/notary/server/Cargo.toml index c3178f968..a2c04e6e2 100644 --- a/crates/notary/server/Cargo.toml +++ b/crates/notary/server/Cargo.toml @@ -19,17 +19,17 @@ tlsn-core = { workspace = true } tlsn-common = { workspace = true } tlsn-verifier = { workspace = true } -async-trait = { workspace = true } async-tungstenite = { workspace = true, features = ["tokio-native-tls"] } axum = { workspace = true, features = ["ws"] } -axum-core = { version = "0.4" } -axum-macros = { version = "0.4" } +axum-core = { version = "0.5" } +axum-macros = { version = "0.5" } base64 = { version = "0.21" } config = { version = "0.14", features = ["yaml"] } csv = { version = "1.3" } eyre = { version = "0.6" } futures-util = { workspace = true } http = { workspace = true } +http-body-util = { workspace = true } hyper = { workspace = true, features = ["client", "http1", "server"] } hyper-util = { workspace = true, features = ["full"] } k256 = { workspace = true } diff --git a/crates/notary/server/src/middleware.rs b/crates/notary/server/src/middleware.rs index 84f0b40a4..925827584 100644 --- a/crates/notary/server/src/middleware.rs +++ b/crates/notary/server/src/middleware.rs @@ -1,4 +1,3 @@ -use async_trait::async_trait; use axum::http::{header, request::Parts}; use axum_core::extract::{FromRef, FromRequestParts}; use std::collections::HashMap; @@ -12,7 +11,6 @@ use crate::{ /// Auth middleware to prevent DOS pub struct AuthorizationMiddleware; -#[async_trait] impl FromRequestParts for AuthorizationMiddleware where NotaryGlobals: FromRef, diff --git a/crates/notary/server/src/service.rs b/crates/notary/server/src/service.rs index e58deed74..2ba50e3fb 100644 --- a/crates/notary/server/src/service.rs +++ b/crates/notary/server/src/service.rs @@ -2,7 +2,6 @@ pub mod axum_websocket; pub mod tcp; pub mod websocket; -use async_trait::async_trait; use axum::{ extract::{rejection::JsonRejection, FromRequestParts, Query, State}, http::{header, request::Parts, StatusCode}, @@ -43,7 +42,6 @@ pub enum ProtocolUpgrade { Ws(WebSocketUpgrade), } -#[async_trait] impl FromRequestParts for ProtocolUpgrade where S: Send + Sync, @@ -56,17 +54,17 @@ where let extractor = WebSocketUpgrade::from_request_parts(parts, state) .await .map_err(|err| NotaryServerError::BadProverRequest(err.to_string()))?; - return Ok(Self::Ws(extractor)); + Ok(Self::Ws(extractor)) // Extract tcp connection for tcp client } else if header_eq(&parts.headers, header::UPGRADE, "tcp") { let extractor = TcpUpgrade::from_request_parts(parts, state) .await .map_err(|err| NotaryServerError::BadProverRequest(err.to_string()))?; - return Ok(Self::Tcp(extractor)); + Ok(Self::Tcp(extractor)) } else { - return Err(NotaryServerError::BadProverRequest( + Err(NotaryServerError::BadProverRequest( "Upgrade header is not set for client".to_string(), - )); + )) } } } diff --git a/crates/notary/server/src/service/axum_websocket.rs b/crates/notary/server/src/service/axum_websocket.rs index 991391b6f..53d1c3aa9 100644 --- a/crates/notary/server/src/service/axum_websocket.rs +++ b/crates/notary/server/src/service/axum_websocket.rs @@ -1,4 +1,4 @@ -//! The following code is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.7.3/axum/src/extract/ws.rs +//! The following code is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.8.0/axum/src/extract/ws.rs //! where we swapped out tokio_tungstenite (https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/) //! with async_tungstenite (https://docs.rs/async-tungstenite/latest/async_tungstenite/) so that we can use //! ws_stream_tungstenite (https://docs.rs/ws_stream_tungstenite/latest/ws_stream_tungstenite/index.html) @@ -41,12 +41,12 @@ //! ``` //! use axum::{ //! extract::ws::{WebSocketUpgrade, WebSocket}, -//! routing::get, +//! routing::any, //! response::{IntoResponse, Response}, //! Router, //! }; //! -//! let app = Router::new().route("/ws", get(handler)); +//! let app = Router::new().route("/ws", any(handler)); //! //! async fn handler(ws: WebSocketUpgrade) -> Response { //! ws.on_upgrade(handle_socket) @@ -76,7 +76,7 @@ //! use axum::{ //! extract::{ws::{WebSocketUpgrade, WebSocket}, State}, //! response::Response, -//! routing::get, +//! routing::any, //! Router, //! }; //! @@ -94,7 +94,7 @@ //! } //! //! let app = Router::new() -//! .route("/ws", get(handler)) +//! .route("/ws", any(handler)) //! .with_state(AppState { /* ... */ }); //! # let _: Router = app; //! ``` @@ -125,10 +125,11 @@ //! ``` //! //! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split -#![allow(unused)] + +#![allow(unused)] // NOTARY_MODIFICATION use self::rejection::*; -use async_trait::async_trait; +/// NOTARY_MODIFICATION: async_tungstenite instead of tokio_tungstenite use async_tungstenite::{ tokio::TokioAdapter, tungstenite::{ @@ -137,6 +138,7 @@ use async_tungstenite::{ }, WebSocketStream, }; +/// NOTARY_MODIFICATION: axum use axum::{body::Bytes, extract::FromRequestParts, response::Response, Error}; use axum_core::body::Body; use futures_util::{ @@ -146,7 +148,7 @@ use futures_util::{ use http::{ header::{self, HeaderMap, HeaderName, HeaderValue}, request::Parts, - Method, StatusCode, + Method, StatusCode, Version, }; use hyper_util::rt::TokioIo; use sha1::{Digest, Sha1}; @@ -160,18 +162,21 @@ use tracing::error; /// Extractor for establishing WebSocket connections. /// -/// Note: This extractor requires the request method to be `GET` so it should -/// always be used with [`get`](crate::routing::get). Requests with other -/// methods will be rejected. +/// For HTTP/1.1 requests, this extractor requires the request method to be +/// `GET`; in later versions, `CONNECT` is used instead. +/// To support both, it should be used with [`any`](crate::routing::any). /// /// See the [module docs](self) for an example. +/// +/// [`MethodFilter`]: crate::routing::MethodFilter #[cfg_attr(docsrs, doc(cfg(feature = "ws")))] pub struct WebSocketUpgrade { config: WebSocketConfig, /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the /// response. protocol: Option, - sec_websocket_key: HeaderValue, + /// `None` if HTTP/2+ WebSockets are used. + sec_websocket_key: Option, on_upgrade: hyper::upgrade::OnUpgrade, on_failed_upgrade: F, sec_websocket_protocol: Option, @@ -257,12 +262,12 @@ impl WebSocketUpgrade { /// ``` /// use axum::{ /// extract::ws::{WebSocketUpgrade, WebSocket}, - /// routing::get, + /// routing::any, /// response::{IntoResponse, Response}, /// Router, /// }; /// - /// let app = Router::new().route("/ws", get(handler)); + /// let app = Router::new().route("/ws", any(handler)); /// /// async fn handler(ws: WebSocketUpgrade) -> Response { /// ws.protocols(["graphql-ws", "graphql-transport-ws"]) @@ -358,6 +363,7 @@ impl WebSocketUpgrade { let upgraded = match on_upgrade.await { Ok(upgraded) => upgraded, Err(err) => { + // NOTARY_MODIFICATION: log error error!("Something wrong with on_upgrade: {:?}", err); on_failed_upgrade.call(Error::new(err)); return; @@ -380,25 +386,34 @@ impl WebSocketUpgrade { callback(socket).await; }); - #[allow(clippy::declare_interior_mutable_const)] - const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); - #[allow(clippy::declare_interior_mutable_const)] - const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); + if let Some(sec_websocket_key) = &self.sec_websocket_key { + // If `sec_websocket_key` was `Some`, we are using HTTP/1.1. - let mut builder = Response::builder() - .status(StatusCode::SWITCHING_PROTOCOLS) - .header(header::CONNECTION, UPGRADE) - .header(header::UPGRADE, WEBSOCKET) - .header( - header::SEC_WEBSOCKET_ACCEPT, - sign(self.sec_websocket_key.as_bytes()), - ); + #[allow(clippy::declare_interior_mutable_const)] + const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); + #[allow(clippy::declare_interior_mutable_const)] + const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); - if let Some(protocol) = self.protocol { - builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol); + let mut builder = Response::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(header::CONNECTION, UPGRADE) + .header(header::UPGRADE, WEBSOCKET) + .header( + header::SEC_WEBSOCKET_ACCEPT, + sign(sec_websocket_key.as_bytes()), + ); + + if let Some(protocol) = self.protocol { + builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol); + } + + builder.body(Body::empty()).unwrap() + } else { + // Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just + // respond with a 2XX with an empty body: + // . + Response::new(Body::empty()) } - - builder.body(Body::empty()).unwrap() } } @@ -431,7 +446,6 @@ impl OnFailedUpgrade for DefaultOnFailedUpgrade { fn call(self, _error: Error) {} } -#[async_trait] impl FromRequestParts for WebSocketUpgrade where S: Send + Sync, @@ -439,28 +453,51 @@ where type Rejection = WebSocketUpgradeRejection; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { - if parts.method != Method::GET { - return Err(MethodNotGet.into()); - } + let sec_websocket_key = if parts.version <= Version::HTTP_11 { + if parts.method != Method::GET { + return Err(MethodNotGet.into()); + } - if !header_contains(&parts.headers, header::CONNECTION, "upgrade") { - return Err(InvalidConnectionHeader.into()); - } + if !header_contains(&parts.headers, header::CONNECTION, "upgrade") { + return Err(InvalidConnectionHeader.into()); + } - if !header_eq(&parts.headers, header::UPGRADE, "websocket") { - return Err(InvalidUpgradeHeader.into()); - } + if !header_eq(&parts.headers, header::UPGRADE, "websocket") { + return Err(InvalidUpgradeHeader.into()); + } + + Some( + parts + .headers + .get(header::SEC_WEBSOCKET_KEY) + .ok_or(WebSocketKeyHeaderMissing)? + .clone(), + ) + } else { + if parts.method != Method::CONNECT { + return Err(MethodNotConnect.into()); + } + + // NOTARY_MODIFICATION: ignore http2 feature + + // if this feature flag is disabled, we won’t be receiving an HTTP/2 request to + // begin with. + // #[cfg(feature = "http2")] + // if parts + // .extensions + // .get::() + // .map_or(true, |p| p.as_str() != "websocket") + // { + // return Err(InvalidProtocolPseudoheader.into()); + // } + + None + }; if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") { return Err(InvalidWebSocketVersionHeader.into()); } - let sec_websocket_key = parts - .headers - .get(header::SEC_WEBSOCKET_KEY) - .ok_or(WebSocketKeyHeaderMissing)? - .clone(); - let on_upgrade = parts .extensions .remove::() @@ -507,6 +544,7 @@ fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> /// See [the module level documentation](self) for more details. #[derive(Debug)] pub struct WebSocket { + // NOTARY_MODIFICATION: TokioAdapter inner: WebSocketStream>>, protocol: Option, } @@ -533,11 +571,6 @@ impl WebSocket { .map_err(Error::new) } - /// Gracefully close this WebSocket. - pub async fn close(mut self) -> Result<(), Error> { - self.inner.close(None).await.map_err(Error::new) - } - /// Return the selected WebSocket subprotocol, if one has been chosen. pub fn protocol(&self) -> Option<&HeaderValue> { self.protocol.as_ref() @@ -584,17 +617,137 @@ impl Sink for WebSocket { } } +/// UTF-8 wrapper for [Bytes]. +/// +/// An [Utf8Bytes] is always guaranteed to contain valid UTF-8. +/// The following NOTARY_MODIFICATION(s) are required because +/// `async_tungstenite` (v0.28.2) is using an older version of `tungstenite` +/// than `tokio_tungstenite` (v0.26.1). This older version of `tungstenite` +/// (v0.26.0) doesn't have `Utf8Bytes`. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct Utf8Bytes(axum::extract::ws::Utf8Bytes); // NOTARY_MODIFICATION + +impl Utf8Bytes { + /// Creates from a static str. + #[inline] + pub const fn from_static(str: &'static str) -> Self { + Self(axum::extract::ws::Utf8Bytes::from_static(str)) // NOTARY_MODIFICATION + } + + /// Returns as a string slice. + #[inline] + pub fn as_str(&self) -> &str { + self.0.as_str() + } + + fn into_tungstenite(self) -> axum::extract::ws::Utf8Bytes { + // NOTARY_MODIFICATION + self.0 + } +} + +impl std::ops::Deref for Utf8Bytes { + type Target = str; + + /// ``` + /// /// Example fn that takes a str slice + /// fn a(s: &str) {} + /// + /// let data = axum::extract::ws::Utf8Bytes::from_static("foo123"); + /// + /// // auto-deref as arg + /// a(&data); + /// + /// // deref to str methods + /// assert_eq!(data.len(), 6); + /// ``` + #[inline] + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +impl std::fmt::Display for Utf8Bytes { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +impl TryFrom for Utf8Bytes { + type Error = std::str::Utf8Error; + + #[inline] + fn try_from(bytes: Bytes) -> Result { + Ok(Self(bytes.try_into()?)) + } +} + +impl TryFrom> for Utf8Bytes { + type Error = std::str::Utf8Error; + + #[inline] + fn try_from(v: Vec) -> Result { + Ok(Self(v.try_into()?)) + } +} + +impl From for Utf8Bytes { + #[inline] + fn from(s: String) -> Self { + Self(s.into()) + } +} + +impl From<&str> for Utf8Bytes { + #[inline] + fn from(s: &str) -> Self { + Self(s.into()) + } +} + +impl From<&String> for Utf8Bytes { + #[inline] + fn from(s: &String) -> Self { + Self(s.into()) + } +} + +impl From for Bytes { + #[inline] + fn from(Utf8Bytes(bytes): Utf8Bytes) -> Self { + bytes.into() + } +} + +impl PartialEq for Utf8Bytes +where + for<'a> &'a str: PartialEq, +{ + /// ``` + /// let payload = axum::extract::ws::Utf8Bytes::from_static("foo123"); + /// assert_eq!(payload, "foo123"); + /// assert_eq!(payload, "foo123".to_string()); + /// assert_eq!(payload, &"foo123".to_string()); + /// assert_eq!(payload, std::borrow::Cow::from("foo123")); + /// ``` + #[inline] + fn eq(&self, other: &T) -> bool { + self.as_str() == *other + } +} + /// Status code used to indicate why an endpoint is closing the WebSocket /// connection. pub type CloseCode = u16; /// A struct representing the close command. #[derive(Debug, Clone, Eq, PartialEq)] -pub struct CloseFrame<'t> { +pub struct CloseFrame { /// The reason as a code. pub code: CloseCode, /// The reason as text string. - pub reason: Cow<'t, str>, + pub reason: Utf8Bytes, } /// A WebSocket message. @@ -623,16 +776,16 @@ pub struct CloseFrame<'t> { #[derive(Debug, Eq, PartialEq, Clone)] pub enum Message { /// A text WebSocket message - Text(String), + Text(Utf8Bytes), /// A binary WebSocket message - Binary(Vec), + Binary(Bytes), /// A ping message with the specified payload /// /// The payload here must have a length less than 125 bytes. /// /// Ping messages will be automatically responded to by the server, so you /// do not have to worry about dealing with them yourself. - Ping(Vec), + Ping(Bytes), /// A pong message with the specified payload /// /// The payload here must have a length less than 125 bytes. @@ -640,21 +793,44 @@ pub enum Message { /// Pong messages will be automatically sent to the client if a ping message /// is received, so you do not have to worry about constructing them /// yourself unless you want to implement a [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3). - Pong(Vec), + Pong(Bytes), /// A close message with the optional close frame. - Close(Option>), + /// + /// You may "uncleanly" close a WebSocket connection at any time + /// by simply dropping the [`WebSocket`]. + /// However, you may also use the graceful closing protocol, in which + /// 1. peer A sends a close frame, and does not send any further messages; + /// 2. peer B responds with a close frame, and does not send any further + /// messages; + /// 3. peer A processes the remaining messages sent by peer B, before + /// finally + /// 4. both peers close the connection. + /// + /// After sending a close frame, + /// you may still read messages, + /// but attempts to send another message will error. + /// After receiving a close frame, + /// axum will automatically respond with a close frame if necessary + /// (you do not have to deal with this yourself). + /// Since no further messages will be received, + /// you may either do nothing + /// or explicitly drop the connection. + Close(Option), } +/// The following NOTARY_MODIFICATION(s) are required because +/// `async_tungstenite` (v0.28.2) is using an older version of `tungstenite` +/// than `tokio_tungstenite` (v0.26.1). impl Message { fn into_tungstenite(self) -> ts::Message { match self { - Self::Text(text) => ts::Message::Text(text), - Self::Binary(binary) => ts::Message::Binary(binary), - Self::Ping(ping) => ts::Message::Ping(ping), - Self::Pong(pong) => ts::Message::Pong(pong), + Self::Text(text) => ts::Message::Text(text.into_tungstenite().to_string()), /* NOTARY_MODIFICATION */ + Self::Binary(binary) => ts::Message::Binary(binary.to_vec()), /* NOTARY_MODIFICATION */ + Self::Ping(ping) => ts::Message::Ping(ping.to_vec()), /* NOTARY_MODIFICATION */ + Self::Pong(pong) => ts::Message::Pong(pong.to_vec()), /* NOTARY_MODIFICATION */ Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame { code: ts::protocol::frame::coding::CloseCode::from(close.code), - reason: close.reason, + reason: Cow::Owned(close.reason.into_tungstenite().to_string()), /* NOTARY_MODIFICATION */ })), Self::Close(None) => ts::Message::Close(None), } @@ -662,13 +838,13 @@ impl Message { fn from_tungstenite(message: ts::Message) -> Option { match message { - ts::Message::Text(text) => Some(Self::Text(text)), - ts::Message::Binary(binary) => Some(Self::Binary(binary)), - ts::Message::Ping(ping) => Some(Self::Ping(ping)), - ts::Message::Pong(pong) => Some(Self::Pong(pong)), + ts::Message::Text(text) => Some(Self::Text(Utf8Bytes(text.into()))), /* NOTARY_MODIFICATION */ + ts::Message::Binary(binary) => Some(Self::Binary(binary.into())), /* NOTARY_MODIFICATION */ + ts::Message::Ping(ping) => Some(Self::Ping(ping.into())), /* NOTARY_MODIFICATION */ + ts::Message::Pong(pong) => Some(Self::Pong(pong.into())), /* NOTARY_MODIFICATION */ ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame { code: close.code.into(), - reason: close.reason, + reason: Utf8Bytes(close.reason.to_string().into()), /* NOTARY_MODIFICATION */ }))), ts::Message::Close(None) => Some(Self::Close(None)), // we can ignore `Frame` frames as recommended by the tungstenite maintainers @@ -678,24 +854,24 @@ impl Message { } /// Consume the WebSocket and return it as binary data. - pub fn into_data(self) -> Vec { + pub fn into_data(self) -> Bytes { match self { - Self::Text(string) => string.into_bytes(), + Self::Text(string) => Bytes::from(string), Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data, - Self::Close(None) => Vec::new(), - Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(), + Self::Close(None) => Bytes::new(), + Self::Close(Some(frame)) => Bytes::from(frame.reason), } } - /// Attempt to consume the WebSocket message and convert it to a String. - pub fn into_text(self) -> Result { + /// Attempt to consume the WebSocket message and convert it to a Utf8Bytes. + pub fn into_text(self) -> Result { match self { Self::Text(string) => Ok(string), - Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data) - .map_err(|err| err.utf8_error()) - .map_err(Error::new)?), - Self::Close(None) => Ok(String::new()), - Self::Close(Some(frame)) => Ok(frame.reason.into_owned()), + Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => { + Ok(Utf8Bytes::try_from(data).map_err(Error::new)?) + } + Self::Close(None) => Ok(Utf8Bytes::default()), + Self::Close(Some(frame)) => Ok(frame.reason), } } @@ -703,7 +879,7 @@ impl Message { /// this will try to convert binary data to utf8. pub fn to_text(&self) -> Result<&str, Error> { match *self { - Self::Text(ref string) => Ok(string), + Self::Text(ref string) => Ok(string.as_str()), Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => { Ok(std::str::from_utf8(data).map_err(Error::new)?) } @@ -711,11 +887,27 @@ impl Message { Self::Close(Some(ref frame)) => Ok(&frame.reason), } } + + /// Create a new text WebSocket message from a stringable. + pub fn text(string: S) -> Message + where + S: Into, + { + Message::Text(string.into()) + } + + /// Create a new binary WebSocket message by converting to `Bytes`. + pub fn binary(bin: B) -> Message + where + B: Into, + { + Message::Binary(bin.into()) + } } impl From for Message { fn from(string: String) -> Self { - Message::Text(string) + Message::Text(string.into()) } } @@ -727,19 +919,19 @@ impl<'s> From<&'s str> for Message { impl<'b> From<&'b [u8]> for Message { fn from(data: &'b [u8]) -> Self { - Message::Binary(data.into()) + Message::Binary(Bytes::copy_from_slice(data)) } } impl From> for Message { fn from(data: Vec) -> Self { - Message::Binary(data) + Message::Binary(data.into()) } } impl From for Vec { fn from(msg: Message) -> Self { - msg.into_data() + msg.into_data().to_vec() } } @@ -767,6 +959,13 @@ pub mod rejection { pub struct MethodNotGet; } + define_rejection! { + #[status = METHOD_NOT_ALLOWED] + #[body = "Request method must be `CONNECT`"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + pub struct MethodNotConnect; + } + define_rejection! { #[status = BAD_REQUEST] #[body = "Connection header did not include 'upgrade'"] @@ -781,6 +980,13 @@ pub mod rejection { pub struct InvalidUpgradeHeader; } + define_rejection! { + #[status = BAD_REQUEST] + #[body = "`:protocol` pseudo-header did not include 'websocket'"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + pub struct InvalidProtocolPseudoheader; + } + define_rejection! { #[status = BAD_REQUEST] #[body = "`Sec-WebSocket-Version` header did not include '13'"] @@ -816,8 +1022,10 @@ pub mod rejection { /// extractor can fail. pub enum WebSocketUpgradeRejection { MethodNotGet, + MethodNotConnect, InvalidConnectionHeader, InvalidUpgradeHeader, + InvalidProtocolPseudoheader, InvalidWebSocketVersionHeader, WebSocketKeyHeaderMissing, ConnectionNotUpgradable, @@ -843,9 +1051,10 @@ pub mod close_code { pub const PROTOCOL: u16 = 1002; /// Indicates that an endpoint is terminating the connection because it has - /// received a type of data it cannot accept (e.g., an endpoint that - /// understands only text data MAY send this if it receives a binary - /// message). + /// received a type of data that it cannot accept. + /// + /// For example, an endpoint MAY send this if it understands only text data, + /// but receives a binary message. pub const UNSUPPORTED: u16 = 1003; /// Indicates that no status code was included in a closing frame. @@ -856,14 +1065,18 @@ pub mod close_code { /// Indicates that an endpoint is terminating the connection because it has /// received data within a message that was not consistent with the type - /// of the message (e.g., non-UTF-8 RFC3629 data within a text message). + /// of the message. + /// + /// For example, an endpoint received non-UTF-8 RFC3629 data within a text + /// message. pub const INVALID: u16 = 1007; /// Indicates that an endpoint is terminating the connection because it has - /// received a message that violates its policy. This is a generic - /// status code that can be returned when there is no other more - /// suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a - /// need to hide specific details about the policy. + /// received a message that violates its policy. + /// + /// This is a generic status code that can be returned when there is + /// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if + /// there is a need to hide specific details about the policy. pub const POLICY: u16 = 1008; /// Indicates that an endpoint is terminating the connection because it has @@ -871,12 +1084,15 @@ pub mod close_code { pub const SIZE: u16 = 1009; /// Indicates that an endpoint (client) is terminating the connection - /// because it has expected the server to negotiate one or more - /// extension, but the server didn't return them in the response message - /// of the WebSocket handshake. The list of extensions that are needed - /// should be given as the reason for closing. Note that this status - /// code is not used by the server, because it can fail the WebSocket - /// handshake instead. + /// because the server did not respond to extension negotiation + /// correctly. + /// + /// Specifically, the client has expected the server to negotiate one or + /// more extension(s), but the server didn't return them in the response + /// message of the WebSocket handshake. The list of extensions that are + /// needed should be given as the reason for closing. Note that this + /// status code is not used by the server, because it can fail the + /// WebSocket handshake instead. pub const EXTENSION: u16 = 1010; /// Indicates that a server is terminating the connection because it @@ -896,14 +1112,22 @@ pub mod close_code { #[cfg(test)] mod tests { use super::*; - use axum::{body::Body, routing::get, Router}; + use async_tungstenite::tungstenite; // NOTARY_MODIFICATION + use axum::routing::any; + use axum::{body::Body, routing::get, Router}; // NOTARY_MODIFICATION use http::{Request, Version}; - // NOTARY_MODIFICATION: use tower_util instead of tower to make clippy happy + use http_body_util::BodyExt as _; + use hyper_util::rt::TokioExecutor; + use std::future::ready; + use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpStream, + }; use tower_util::ServiceExt; - #[tokio::test] + #[tokio::test] // NOTARY_MODIFICATION async fn rejects_http_1_0_requests() { - let svc = get(|ws: Result| { + let svc = any(|ws: Result| { let rejection = ws.unwrap_err(); assert!(matches!( rejection, @@ -932,7 +1156,7 @@ mod tests { async fn handler(ws: WebSocketUpgrade) -> Response { ws.on_upgrade(|_| async {}) } - let _: Router = Router::new().route("/", get(handler)); + let _: Router = Router::new().route("/", any(handler)); } #[allow(dead_code)] @@ -941,6 +1165,8 @@ mod tests { ws.on_failed_upgrade(|_error: Error| println!("oops!")) .on_upgrade(|_| async {}) } - let _: Router = Router::new().route("/", get(handler)); + let _: Router = Router::new().route("/", any(handler)); } + + // NOTARY_MODIFICATION: removed integration test } diff --git a/crates/notary/server/src/service/tcp.rs b/crates/notary/server/src/service/tcp.rs index e2b433de6..0aaf4e523 100644 --- a/crates/notary/server/src/service/tcp.rs +++ b/crates/notary/server/src/service/tcp.rs @@ -1,4 +1,3 @@ -use async_trait::async_trait; use axum::{ extract::FromRequestParts, http::{header, request::Parts, HeaderValue, StatusCode}, @@ -22,7 +21,6 @@ pub struct TcpUpgrade { pub on_upgrade: OnUpgrade, } -#[async_trait] impl FromRequestParts for TcpUpgrade where S: Send + Sync, diff --git a/crates/server-fixture/server/src/lib.rs b/crates/server-fixture/server/src/lib.rs index d33a5902d..8c4a63b3a 100644 --- a/crates/server-fixture/server/src/lib.rs +++ b/crates/server-fixture/server/src/lib.rs @@ -26,7 +26,7 @@ use serde_json::Value; use tokio_util::compat::FuturesAsyncReadCompatExt; use tower_service::Service; -use axum::{async_trait, extract::FromRequest}; +use axum::extract::FromRequest; use hyper::header; use tlsn_server_fixture_certs::*; @@ -143,10 +143,9 @@ async fn html( struct AuthenticatedUser; -#[async_trait] impl FromRequest for AuthenticatedUser where - B: Send, + B: Send + Sync, { type Rejection = (StatusCode, &'static str);