@@ -1,4 +1,4 @@
//! The following code is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.7.3 /axum/src/extract/ws.rs
//! The following code is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.8.0 /axum/src/extract/ws.rs
//! where we swapped out tokio_tungstenite (https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/)
//! with async_tungstenite (https://docs.rs/async-tungstenite/latest/async_tungstenite/) so that we can use
//! ws_stream_tungstenite (https://docs.rs/ws_stream_tungstenite/latest/ws_stream_tungstenite/index.html)
@@ -41,12 +41,12 @@
//! ```
//! use axum::{
//! extract::ws::{WebSocketUpgrade, WebSocket},
//! routing::get ,
//! routing::any ,
//! response::{IntoResponse, Response},
//! Router,
//! };
//!
//! let app = Router::new().route("/ws", get (handler));
//! let app = Router::new().route("/ws", any (handler));
//!
//! async fn handler(ws: WebSocketUpgrade) -> Response {
//! ws.on_upgrade(handle_socket)
@@ -76,7 +76,7 @@
//! use axum::{
//! extract::{ws::{WebSocketUpgrade, WebSocket}, State},
//! response::Response,
//! routing::get ,
//! routing::any ,
//! Router,
//! };
//!
@@ -94,7 +94,7 @@
//! }
//!
//! let app = Router::new()
//! .route("/ws", get (handler))
//! .route("/ws", any (handler))
//! .with_state(AppState { /* ... */ });
//! # let _: Router = app;
//! ```
@@ -125,10 +125,11 @@
//! ```
//!
//! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
#![ allow(unused) ]
#![ allow(unused) ] // NOTARY_MODIFICATION
use self ::rejection ::* ;
use async_trait ::async_trait ;
/// NOTARY_MODIFICATION: async_tungstenite instead of tokio_tungstenite
use async_tungstenite ::{
tokio ::TokioAdapter ,
tungstenite ::{
@@ -137,6 +138,7 @@ use async_tungstenite::{
} ,
WebSocketStream ,
} ;
/// NOTARY_MODIFICATION: axum
use axum ::{ body ::Bytes , extract ::FromRequestParts , response ::Response , Error } ;
use axum_core ::body ::Body ;
use futures_util ::{
@@ -146,7 +148,7 @@ use futures_util::{
use http ::{
header ::{ self , HeaderMap , HeaderName , HeaderValue } ,
request ::Parts ,
Method , StatusCode ,
Method , StatusCode , Version ,
} ;
use hyper_util ::rt ::TokioIo ;
use sha1 ::{ Digest , Sha1 } ;
@@ -160,18 +162,21 @@ use tracing::error;
/// Extractor for establishing WebSocket connections.
///
/// Note: T his extractor requires the request method to be `GET` so it should
/// always be used with [`get`](crate::routing::get). Requests with other
/// methods will be rejected .
/// For HTTP/1.1 requests, t his extractor requires the request method to be
/// `GET`; in later versions, `CONNECT` is used instead.
/// To support both, it should be used with [`any`](crate::routing::any) .
///
/// See the [module docs](self) for an example.
///
/// [`MethodFilter`]: crate::routing::MethodFilter
#[ cfg_attr(docsrs, doc(cfg(feature = " ws " ))) ]
pub struct WebSocketUpgrade < F = DefaultOnFailedUpgrade > {
config : WebSocketConfig ,
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the
/// response.
protocol : Option < HeaderValue > ,
sec_w ebs ocket_key : HeaderValue ,
/// `None` if HTTP/2+ W ebS ockets are used.
sec_websocket_key : Option < HeaderValue > ,
on_upgrade : hyper ::upgrade ::OnUpgrade ,
on_failed_upgrade : F ,
sec_websocket_protocol : Option < HeaderValue > ,
@@ -257,12 +262,12 @@ impl<F> WebSocketUpgrade<F> {
/// ```
/// use axum::{
/// extract::ws::{WebSocketUpgrade, WebSocket},
/// routing::get ,
/// routing::any ,
/// response::{IntoResponse, Response},
/// Router,
/// };
///
/// let app = Router::new().route("/ws", get (handler));
/// let app = Router::new().route("/ws", any (handler));
///
/// async fn handler(ws: WebSocketUpgrade) -> Response {
/// ws.protocols(["graphql-ws", "graphql-transport-ws"])
@@ -358,6 +363,7 @@ impl<F> WebSocketUpgrade<F> {
let upgraded = match on_upgrade . await {
Ok ( upgraded ) = > upgraded ,
Err ( err ) = > {
// NOTARY_MODIFICATION: log error
error! ( " Something wrong with on_upgrade: {:?} " , err ) ;
on_failed_upgrade . call ( Error ::new ( err ) ) ;
return ;
@@ -380,25 +386,34 @@ impl<F> WebSocketUpgrade<F> {
callback ( socket ) . await ;
} ) ;
#[ allow(clippy::declare_interior_mutable_const) ]
const UPGRADE : HeaderValue = HeaderValue ::from_static ( " upgrade " ) ;
#[ allow(clippy::declare_interior_mutable_const) ]
const WEBSOCKET : HeaderValue = HeaderValue ::from_static ( " websocket " ) ;
if let Some ( sec_websocket_key ) = & self . sec_websocket_key {
// If `sec_websocket_key` was `Some`, we are using HTTP/1.1.
let mut builder = Response ::builder ( )
. status ( StatusCode ::SWITCHING_PROTOCOLS )
. header ( header ::CONNECTION , UPGRADE )
. header ( h eader ::UPGRADE , WEBSOCKET )
. header (
header ::SEC_WEBSOCKET_ACCEPT ,
sign ( self . sec_websocket_key . as_bytes ( ) ) ,
) ;
#[ allow(clippy::declare_interior_mutable_const) ]
const UPGRADE : HeaderValue = HeaderValue ::from_ static ( " upgrade " ) ;
#[ allow(clippy::declare_interior_mutable_const) ]
const WEBSOCKET : HeaderValue = H eaderValue ::from_static ( " websocket " ) ;
if let Some ( protocol ) = self . protocol {
builder = builder . header ( hea der ::SEC_WEBSOCKET_PROTOCOL , protocol ) ;
let mut builder = Response ::builder ( )
. status ( StatusCo de ::SWITCHING_PROTOCOLS )
. header ( header ::CONNECTION , UPGRADE )
. header ( header ::UPGRADE , WEBSOCKET )
. header (
header ::SEC_WEBSOCKET_ACCEPT ,
sign ( sec_websocket_key . as_bytes ( ) ) ,
) ;
if let Some ( protocol ) = self . protocol {
builder = builder . header ( header ::SEC_WEBSOCKET_PROTOCOL , protocol ) ;
}
builder . body ( Body ::empty ( ) ) . unwrap ( )
} else {
// Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just
// respond with a 2XX with an empty body:
// <https://datatracker.ietf.org/doc/html/rfc9113#name-the-connect-method>.
Response ::new ( Body ::empty ( ) )
}
builder . body ( Body ::empty ( ) ) . unwrap ( )
}
}
@@ -431,7 +446,6 @@ impl OnFailedUpgrade for DefaultOnFailedUpgrade {
fn call ( self , _error : Error ) { }
}
#[ async_trait ]
impl < S > FromRequestParts < S > for WebSocketUpgrade < DefaultOnFailedUpgrade >
where
S : Send + Sync ,
@@ -439,28 +453,51 @@ where
type Rejection = WebSocketUpgradeRejection ;
async fn from_request_parts ( parts : & mut Parts , _state : & S ) -> Result < Self , Self ::Rejection > {
if parts . method ! = Method ::GET {
return Err ( M ethodNotGet . into ( ) ) ;
}
let sec_websocket_key = if parts . version < = Version ::HTTP_11 {
if parts . m ethod ! = Method ::GET {
return Err ( MethodNotGet . into ( ) ) ;
}
if ! header_contains ( & parts . headers , header ::CONNECTION , " upgrade " ) {
return Err ( InvalidConnectionHeader . into ( ) ) ;
}
if ! header_contains ( & parts . headers , header ::CONNECTION , " upgrade " ) {
return Err ( InvalidConnectionHeader . into ( ) ) ;
}
if ! header_eq ( & parts . headers , header ::UPGRADE , " websocket " ) {
return Err ( InvalidUpgradeHeader . into ( ) ) ;
}
if ! header_eq ( & parts . headers , header ::UPGRADE , " websocket " ) {
return Err ( InvalidUpgradeHeader . into ( ) ) ;
}
Some (
parts
. headers
. get ( header ::SEC_WEBSOCKET_KEY )
. ok_or ( WebSocketKeyHeaderMissing ) ?
. clone ( ) ,
)
} else {
if parts . method ! = Method ::CONNECT {
return Err ( MethodNotConnect . into ( ) ) ;
}
// NOTARY_MODIFICATION: ignore http2 feature
// if this feature flag is disabled, we won’ t be receiving an HTTP/2 request to
// begin with.
// #[cfg(feature = "http2")]
// if parts
// .extensions
// .get::<hyper::ext::Protocol>()
// .map_or(true, |p| p.as_str() != "websocket")
// {
// return Err(InvalidProtocolPseudoheader.into());
// }
None
} ;
if ! header_eq ( & parts . headers , header ::SEC_WEBSOCKET_VERSION , " 13 " ) {
return Err ( InvalidWebSocketVersionHeader . into ( ) ) ;
}
let sec_websocket_key = parts
. headers
. get ( header ::SEC_WEBSOCKET_KEY )
. ok_or ( WebSocketKeyHeaderMissing ) ?
. clone ( ) ;
let on_upgrade = parts
. extensions
. remove ::< hyper ::upgrade ::OnUpgrade > ( )
@@ -507,6 +544,7 @@ fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) ->
/// See [the module level documentation](self) for more details.
#[ derive(Debug) ]
pub struct WebSocket {
// NOTARY_MODIFICATION: TokioAdapter
inner : WebSocketStream < TokioAdapter < TokioIo < hyper ::upgrade ::Upgraded > > > ,
protocol : Option < HeaderValue > ,
}
@@ -533,11 +571,6 @@ impl WebSocket {
. map_err ( Error ::new )
}
/// Gracefully close this WebSocket.
pub async fn close ( mut self ) -> Result < ( ) , Error > {
self . inner . close ( None ) . await . map_err ( Error ::new )
}
/// Return the selected WebSocket subprotocol, if one has been chosen.
pub fn protocol ( & self ) -> Option < & HeaderValue > {
self . protocol . as_ref ( )
@@ -584,17 +617,137 @@ impl Sink<Message> for WebSocket {
}
}
/// UTF-8 wrapper for [Bytes].
///
/// An [Utf8Bytes] is always guaranteed to contain valid UTF-8.
/// The following NOTARY_MODIFICATION(s) are required because
/// `async_tungstenite` (v0.28.2) is using an older version of `tungstenite`
/// than `tokio_tungstenite` (v0.26.1). This older version of `tungstenite`
/// (v0.26.0) doesn't have `Utf8Bytes`.
#[ derive(Debug, Clone, PartialEq, Eq, Default) ]
pub struct Utf8Bytes ( axum ::extract ::ws ::Utf8Bytes ) ; // NOTARY_MODIFICATION
impl Utf8Bytes {
/// Creates from a static str.
#[ inline ]
pub const fn from_static ( str : & 'static str ) -> Self {
Self ( axum ::extract ::ws ::Utf8Bytes ::from_static ( str ) ) // NOTARY_MODIFICATION
}
/// Returns as a string slice.
#[ inline ]
pub fn as_str ( & self ) -> & str {
self . 0. as_str ( )
}
fn into_tungstenite ( self ) -> axum ::extract ::ws ::Utf8Bytes {
// NOTARY_MODIFICATION
self . 0
}
}
impl std ::ops ::Deref for Utf8Bytes {
type Target = str ;
/// ```
/// /// Example fn that takes a str slice
/// fn a(s: &str) {}
///
/// let data = axum::extract::ws::Utf8Bytes::from_static("foo123");
///
/// // auto-deref as arg
/// a(&data);
///
/// // deref to str methods
/// assert_eq!(data.len(), 6);
/// ```
#[ inline ]
fn deref ( & self ) -> & Self ::Target {
self . as_str ( )
}
}
impl std ::fmt ::Display for Utf8Bytes {
#[ inline ]
fn fmt ( & self , f : & mut std ::fmt ::Formatter < '_ > ) -> std ::fmt ::Result {
f . write_str ( self . as_str ( ) )
}
}
impl TryFrom < Bytes > for Utf8Bytes {
type Error = std ::str ::Utf8Error ;
#[ inline ]
fn try_from ( bytes : Bytes ) -> Result < Self , Self ::Error > {
Ok ( Self ( bytes . try_into ( ) ? ) )
}
}
impl TryFrom < Vec < u8 > > for Utf8Bytes {
type Error = std ::str ::Utf8Error ;
#[ inline ]
fn try_from ( v : Vec < u8 > ) -> Result < Self , Self ::Error > {
Ok ( Self ( v . try_into ( ) ? ) )
}
}
impl From < String > for Utf8Bytes {
#[ inline ]
fn from ( s : String ) -> Self {
Self ( s . into ( ) )
}
}
impl From < & str > for Utf8Bytes {
#[ inline ]
fn from ( s : & str ) -> Self {
Self ( s . into ( ) )
}
}
impl From < & String > for Utf8Bytes {
#[ inline ]
fn from ( s : & String ) -> Self {
Self ( s . into ( ) )
}
}
impl From < Utf8Bytes > for Bytes {
#[ inline ]
fn from ( Utf8Bytes ( bytes ) : Utf8Bytes ) -> Self {
bytes . into ( )
}
}
impl < T > PartialEq < T > for Utf8Bytes
where
for < ' a > & ' a str : PartialEq < T > ,
{
/// ```
/// let payload = axum::extract::ws::Utf8Bytes::from_static("foo123");
/// assert_eq!(payload, "foo123");
/// assert_eq!(payload, "foo123".to_string());
/// assert_eq!(payload, &"foo123".to_string());
/// assert_eq!(payload, std::borrow::Cow::from("foo123"));
/// ```
#[ inline ]
fn eq ( & self , other : & T ) -> bool {
self . as_str ( ) = = * other
}
}
/// Status code used to indicate why an endpoint is closing the WebSocket
/// connection.
pub type CloseCode = u16 ;
/// A struct representing the close command.
#[ derive(Debug, Clone, Eq, PartialEq) ]
pub struct CloseFrame < ' t > {
pub struct CloseFrame {
/// The reason as a code.
pub code : CloseCode ,
/// The reason as text string.
pub reason : Cow < ' t , str > ,
pub reason : Utf8Bytes ,
}
/// A WebSocket message.
@@ -623,16 +776,16 @@ pub struct CloseFrame<'t> {
#[ derive(Debug, Eq, PartialEq, Clone) ]
pub enum Message {
/// A text WebSocket message
Text ( String ) ,
Text ( Utf8Bytes ) ,
/// A binary WebSocket message
Binary ( Vec < u8 > ) ,
Binary ( Bytes ) ,
/// A ping message with the specified payload
///
/// The payload here must have a length less than 125 bytes.
///
/// Ping messages will be automatically responded to by the server, so you
/// do not have to worry about dealing with them yourself.
Ping ( Vec < u8 > ) ,
Ping ( Bytes ) ,
/// A pong message with the specified payload
///
/// The payload here must have a length less than 125 bytes.
@@ -640,21 +793,44 @@ pub enum Message {
/// Pong messages will be automatically sent to the client if a ping message
/// is received, so you do not have to worry about constructing them
/// yourself unless you want to implement a [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
Pong ( Vec < u8 > ) ,
Pong ( Bytes ) ,
/// A close message with the optional close frame.
Close ( Option < CloseFrame < 'static > > ) ,
///
/// You may "uncleanly" close a WebSocket connection at any time
/// by simply dropping the [`WebSocket`].
/// However, you may also use the graceful closing protocol, in which
/// 1. peer A sends a close frame, and does not send any further messages;
/// 2. peer B responds with a close frame, and does not send any further
/// messages;
/// 3. peer A processes the remaining messages sent by peer B, before
/// finally
/// 4. both peers close the connection.
///
/// After sending a close frame,
/// you may still read messages,
/// but attempts to send another message will error.
/// After receiving a close frame,
/// axum will automatically respond with a close frame if necessary
/// (you do not have to deal with this yourself).
/// Since no further messages will be received,
/// you may either do nothing
/// or explicitly drop the connection.
Close ( Option < CloseFrame > ) ,
}
/// The following NOTARY_MODIFICATION(s) are required because
/// `async_tungstenite` (v0.28.2) is using an older version of `tungstenite`
/// than `tokio_tungstenite` (v0.26.1).
impl Message {
fn into_tungstenite ( self ) -> ts ::Message {
match self {
Self ::Text ( text ) = > ts ::Message ::Text ( text ) ,
Self ::Binary ( binary ) = > ts ::Message ::Binary ( binary ) ,
Self ::Ping ( ping ) = > ts ::Message ::Ping ( ping ) ,
Self ::Pong ( pong ) = > ts ::Message ::Pong ( pong ) ,
Self ::Text ( text ) = > ts ::Message ::Text ( text . into_tungstenite ( ) . to_string ( ) ) , /* NOTARY_MODIFICATION */
Self ::Binary ( binary ) = > ts ::Message ::Binary ( binary . to_vec ( ) ) , /* NOTARY_MODIFICATION */
Self ::Ping ( ping ) = > ts ::Message ::Ping ( ping . to_vec ( ) ) , /* NOTARY_MODIFICATION */
Self ::Pong ( pong ) = > ts ::Message ::Pong ( pong . to_vec ( ) ) , /* NOTARY_MODIFICATION */
Self ::Close ( Some ( close ) ) = > ts ::Message ::Close ( Some ( ts ::protocol ::CloseFrame {
code : ts ::protocol ::frame ::coding ::CloseCode ::from ( close . code ) ,
reason : close . reason ,
reason : Cow ::Owned ( close. reason . into_tungstenite ( ) . to_string ( ) ) , /* NOTARY_MODIFICATION */
} ) ) ,
Self ::Close ( None ) = > ts ::Message ::Close ( None ) ,
}
@@ -662,13 +838,13 @@ impl Message {
fn from_tungstenite ( message : ts ::Message ) -> Option < Self > {
match message {
ts ::Message ::Text ( text ) = > Some ( Self ::Text ( text ) ) ,
ts ::Message ::Binary ( binary ) = > Some ( Self ::Binary ( binary ) ) ,
ts ::Message ::Ping ( ping ) = > Some ( Self ::Ping ( ping ) ) ,
ts ::Message ::Pong ( pong ) = > Some ( Self ::Pong ( pong ) ) ,
ts ::Message ::Text ( text ) = > Some ( Self ::Text ( Utf8Bytes ( text . into ( ) ) ) ) , /* NOTARY_MODIFICATION */
ts ::Message ::Binary ( binary ) = > Some ( Self ::Binary ( binary . into ( ) ) ) , /* NOTARY_MODIFICATION */
ts ::Message ::Ping ( ping ) = > Some ( Self ::Ping ( ping . into ( ) ) ) , /* NOTARY_MODIFICATION */
ts ::Message ::Pong ( pong ) = > Some ( Self ::Pong ( pong . into ( ) ) ) , /* NOTARY_MODIFICATION */
ts ::Message ::Close ( Some ( close ) ) = > Some ( Self ::Close ( Some ( CloseFrame {
code : close . code . into ( ) ,
reason : close . reason ,
reason : Utf8Bytes ( close. reason . to_string ( ) . into ( ) ) , /* NOTARY_MODIFICATION */
} ) ) ) ,
ts ::Message ::Close ( None ) = > Some ( Self ::Close ( None ) ) ,
// we can ignore `Frame` frames as recommended by the tungstenite maintainers
@@ -678,24 +854,24 @@ impl Message {
}
/// Consume the WebSocket and return it as binary data.
pub fn into_data ( self ) -> Vec < u8 > {
pub fn into_data ( self ) -> Bytes {
match self {
Self ::Text ( string ) = > string . into_bytes ( ) ,
Self ::Text ( string ) = > Bytes ::from ( string ) ,
Self ::Binary ( data ) | Self ::Ping ( data ) | Self ::Pong ( data ) = > data ,
Self ::Close ( None ) = > Vec ::new ( ) ,
Self ::Close ( Some ( frame ) ) = > frame . reason . into_owned ( ) . into_bytes ( ) ,
Self ::Close ( None ) = > Bytes ::new ( ) ,
Self ::Close ( Some ( frame ) ) = > Bytes ::from ( frame . reason ) ,
}
}
/// Attempt to consume the WebSocket message and convert it to a String .
pub fn into_text ( self ) -> Result < String , Error > {
/// Attempt to consume the WebSocket message and convert it to a Utf8Bytes .
pub fn into_text ( self ) -> Result < Utf8Bytes , Error > {
match self {
Self ::Text ( string ) = > Ok ( string ) ,
Self ::Binary ( data ) | Self ::Ping ( data ) | Self ::Pong ( data ) = > Ok ( String ::from_utf8 ( data )
. map_err ( | err | err . utf8_error ( ) )
. map_err ( Error ::new ) ? ) ,
Self ::Close ( None ) = > Ok ( String ::new ( ) ) ,
Self ::Close ( Some ( frame ) ) = > Ok ( frame . reason . into_owned ( ) ),
Self ::Binary ( data ) | Self ::Ping ( data ) | Self ::Pong ( data ) = > {
Ok ( Utf8Bytes ::try_from ( data ) . map_err ( Error ::new ) ? )
}
Self ::Close ( None ) = > Ok ( Utf8Bytes ::default ( ) ) ,
Self ::Close ( Some ( frame ) ) = > Ok ( frame . reason ) ,
}
}
@@ -703,7 +879,7 @@ impl Message {
/// this will try to convert binary data to utf8.
pub fn to_text ( & self ) -> Result < & str , Error > {
match * self {
Self ::Text ( ref string ) = > Ok ( string ) ,
Self ::Text ( ref string ) = > Ok ( string . as_str ( ) ),
Self ::Binary ( ref data ) | Self ::Ping ( ref data ) | Self ::Pong ( ref data ) = > {
Ok ( std ::str ::from_utf8 ( data ) . map_err ( Error ::new ) ? )
}
@@ -711,11 +887,27 @@ impl Message {
Self ::Close ( Some ( ref frame ) ) = > Ok ( & frame . reason ) ,
}
}
/// Create a new text WebSocket message from a stringable.
pub fn text < S > ( string : S ) -> Message
where
S : Into < Utf8Bytes > ,
{
Message ::Text ( string . into ( ) )
}
/// Create a new binary WebSocket message by converting to `Bytes`.
pub fn binary < B > ( bin : B ) -> Message
where
B : Into < Bytes > ,
{
Message ::Binary ( bin . into ( ) )
}
}
impl From < String > for Message {
fn from ( string : String ) -> Self {
Message ::Text ( string )
Message ::Text ( string . into ( ) )
}
}
@@ -727,19 +919,19 @@ impl<'s> From<&'s str> for Message {
impl < ' b > From < & ' b [ u8 ] > for Message {
fn from ( data : & ' b [ u8 ] ) -> Self {
Message ::Binary ( data . into ( ) )
Message ::Binary ( Bytes ::copy_from_slice ( data ) )
}
}
impl From < Vec < u8 > > for Message {
fn from ( data : Vec < u8 > ) -> Self {
Message ::Binary ( data )
Message ::Binary ( data . into ( ) )
}
}
impl From < Message > for Vec < u8 > {
fn from ( msg : Message ) -> Self {
msg . into_data ( )
msg . into_data ( ) . to_vec ( )
}
}
@@ -767,6 +959,13 @@ pub mod rejection {
pub struct MethodNotGet ;
}
define_rejection! {
#[ status = METHOD_NOT_ALLOWED ]
#[ body = " Request method must be `CONNECT` " ]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct MethodNotConnect ;
}
define_rejection! {
#[ status = BAD_REQUEST ]
#[ body = " Connection header did not include 'upgrade' " ]
@@ -781,6 +980,13 @@ pub mod rejection {
pub struct InvalidUpgradeHeader ;
}
define_rejection! {
#[ status = BAD_REQUEST ]
#[ body = " `:protocol` pseudo-header did not include 'websocket' " ]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidProtocolPseudoheader ;
}
define_rejection! {
#[ status = BAD_REQUEST ]
#[ body = " `Sec-WebSocket-Version` header did not include '13' " ]
@@ -816,8 +1022,10 @@ pub mod rejection {
/// extractor can fail.
pub enum WebSocketUpgradeRejection {
MethodNotGet ,
MethodNotConnect ,
InvalidConnectionHeader ,
InvalidUpgradeHeader ,
InvalidProtocolPseudoheader ,
InvalidWebSocketVersionHeader ,
WebSocketKeyHeaderMissing ,
ConnectionNotUpgradable ,
@@ -843,9 +1051,10 @@ pub mod close_code {
pub const PROTOCOL : u16 = 1002 ;
/// Indicates that an endpoint is terminating the connection because it has
/// received a type of data it cannot accept (e.g., an endpoint that
/// understands only text data MAY send this if it receives a binary
/// message).
/// received a type of data that it cannot accept.
///
/// For example, an endpoint MAY send this if it understands only text data,
/// but receives a binary message.
pub const UNSUPPORTED : u16 = 1003 ;
/// Indicates that no status code was included in a closing frame.
@@ -856,14 +1065,18 @@ pub mod close_code {
/// Indicates that an endpoint is terminating the connection because it has
/// received data within a message that was not consistent with the type
/// of the message (e.g., non-UTF-8 RFC3629 data within a text message) .
/// of the message.
///
/// For example, an endpoint received non-UTF-8 RFC3629 data within a text
/// message.
pub const INVALID : u16 = 1007 ;
/// Indicates that an endpoint is terminating the connection because it has
/// received a message that violates its policy. This is a generic
/// status code that can be returned when there is no other more
/// suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a
/// need to hide specific details about the policy.
/// received a message that violates its policy.
///
/// This is a generic status code that can be returned when there is
/// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if
/// there is a need to hide specific details about the policy.
pub const POLICY : u16 = 1008 ;
/// Indicates that an endpoint is terminating the connection because it has
@@ -871,12 +1084,15 @@ pub mod close_code {
pub const SIZE : u16 = 1009 ;
/// Indicates that an endpoint (client) is terminating the connection
/// because it has expected the server to negotiate one or more
/// extension, but the server didn't return them in the response message
/// of the WebSocket handshake. The list of extensions that are needed
/// should be given as the reason for closing. Note that this status
/// code is not used by the server, because it can fail the WebSocket
/// handshake instead.
/// because the server did not respond to extension negotiation
/// correctly.
///
/// Specifically, the client has expected the server to negotiate one or
/// more extension(s), but the server didn't return them in the response
/// message of the WebSocket handshake. The list of extensions that are
/// needed should be given as the reason for closing. Note that this
/// status code is not used by the server, because it can fail the
/// WebSocket handshake instead.
pub const EXTENSION : u16 = 1010 ;
/// Indicates that a server is terminating the connection because it
@@ -896,14 +1112,22 @@ pub mod close_code {
#[ cfg(test) ]
mod tests {
use super ::* ;
use axum ::{ body ::Body , routing ::get , Router } ;
use async_tungstenite ::tungstenite ; // NOTARY_MODIFICATION
use axum ::routing ::any ;
use axum ::{ body ::Body , routing ::get , Router } ; // NOTARY_MODIFICATION
use http ::{ Request , Version } ;
// NOTARY_MODIFICATION: use tower_util instead of tower to make clippy happy
use http_body_util ::BodyExt as _ ;
use hyper_util ::rt ::TokioExecutor ;
use std ::future ::ready ;
use tokio ::{
io ::{ AsyncRead , AsyncWrite } ,
net ::TcpStream ,
} ;
use tower_util ::ServiceExt ;
#[ tokio::test ]
#[ tokio::test ] // NOTARY_MODIFICATION
async fn rejects_http_1_0_requests ( ) {
let svc = get ( | ws : Result < WebSocketUpgrade , WebSocketUpgradeRejection > | {
let svc = any ( | ws : Result < WebSocketUpgrade , WebSocketUpgradeRejection > | {
let rejection = ws . unwrap_err ( ) ;
assert! ( matches! (
rejection ,
@@ -932,7 +1156,7 @@ mod tests {
async fn handler ( ws : WebSocketUpgrade ) -> Response {
ws . on_upgrade ( | _ | async { } )
}
let _ : Router = Router ::new ( ) . route ( " / " , get ( handler ) ) ;
let _ : Router = Router ::new ( ) . route ( " / " , any ( handler ) ) ;
}
#[ allow(dead_code) ]
@@ -941,6 +1165,8 @@ mod tests {
ws . on_failed_upgrade ( | _error : Error | println! ( " oops! " ) )
. on_upgrade ( | _ | async { } )
}
let _ : Router = Router ::new ( ) . route ( " / " , get ( handler ) ) ;
let _ : Router = Router ::new ( ) . route ( " / " , any ( handler ) ) ;
}
// NOTARY_MODIFICATION: removed integration test
}