diff --git a/crates/rpc/rpc-builder/src/auth.rs b/crates/rpc/rpc-builder/src/auth.rs index b1a4f4166b..777081a7e6 100644 --- a/crates/rpc/rpc-builder/src/auth.rs +++ b/crates/rpc/rpc-builder/src/auth.rs @@ -1,9 +1,13 @@ -use crate::error::{RpcError, ServerKind}; +use crate::{ + error::{RpcError, ServerKind}, + middleware::RethRpcMiddleware, +}; use http::header::AUTHORIZATION; use jsonrpsee::{ core::{client::SubscriptionClientT, RegisterMethodError}, http_client::HeaderMap, server::{AlreadyStoppedError, RpcModule}, + ws_client::RpcServiceBuilder, Methods, }; use reth_rpc_api::servers::*; @@ -21,7 +25,7 @@ pub use reth_ipc::server::Builder as IpcServerBuilder; /// Server configuration for the auth server. #[derive(Debug)] -pub struct AuthServerConfig { +pub struct AuthServerConfig { /// Where the server should listen. pub(crate) socket_addr: SocketAddr, /// The secret for the auth layer of the server. @@ -32,6 +36,8 @@ pub struct AuthServerConfig { pub(crate) ipc_server_config: Option>, /// IPC endpoint pub(crate) ipc_endpoint: Option, + /// Configurable RPC middleware + pub(crate) rpc_middleware: RpcMiddleware, } // === impl AuthServerConfig === @@ -41,24 +47,51 @@ impl AuthServerConfig { pub const fn builder(secret: JwtSecret) -> AuthServerConfigBuilder { AuthServerConfigBuilder::new(secret) } - +} +impl AuthServerConfig { /// Returns the address the server will listen on. pub const fn address(&self) -> SocketAddr { self.socket_addr } + /// Configures the rpc middleware. + pub fn with_rpc_middleware(self, rpc_middleware: T) -> AuthServerConfig { + let Self { socket_addr, secret, server_config, ipc_server_config, ipc_endpoint, .. } = self; + AuthServerConfig { + socket_addr, + secret, + server_config, + ipc_server_config, + ipc_endpoint, + rpc_middleware, + } + } + /// Convenience function to start a server in one step. - pub async fn start(self, module: AuthRpcModule) -> Result { - let Self { socket_addr, secret, server_config, ipc_server_config, ipc_endpoint } = self; + pub async fn start(self, module: AuthRpcModule) -> Result + where + RpcMiddleware: RethRpcMiddleware, + { + let Self { + socket_addr, + secret, + server_config, + ipc_server_config, + ipc_endpoint, + rpc_middleware, + } = self; // Create auth middleware. let middleware = tower::ServiceBuilder::new().layer(AuthLayer::new(JwtAuthValidator::new(secret))); + let rpc_middleware = RpcServiceBuilder::default().layer(rpc_middleware); + // By default, both http and ws are enabled. let server = ServerBuilder::new() .set_config(server_config.build()) .set_http_middleware(middleware) + .set_rpc_middleware(rpc_middleware) .build(socket_addr) .await .map_err(|err| RpcError::server_error(err, ServerKind::Auth(socket_addr)))?; @@ -86,12 +119,13 @@ impl AuthServerConfig { /// Builder type for configuring an `AuthServerConfig`. #[derive(Debug)] -pub struct AuthServerConfigBuilder { +pub struct AuthServerConfigBuilder { socket_addr: Option, secret: JwtSecret, server_config: Option, ipc_server_config: Option>, ipc_endpoint: Option, + rpc_middleware: RpcMiddleware, } // === impl AuthServerConfigBuilder === @@ -105,6 +139,22 @@ impl AuthServerConfigBuilder { server_config: None, ipc_server_config: None, ipc_endpoint: None, + rpc_middleware: Identity::new(), + } + } +} + +impl AuthServerConfigBuilder { + /// Configures the rpc middleware. + pub fn with_rpc_middleware(self, rpc_middleware: T) -> AuthServerConfigBuilder { + let Self { socket_addr, secret, server_config, ipc_server_config, ipc_endpoint, .. } = self; + AuthServerConfigBuilder { + socket_addr, + secret, + server_config, + ipc_server_config, + ipc_endpoint, + rpc_middleware, } } @@ -150,7 +200,7 @@ impl AuthServerConfigBuilder { } /// Build the `AuthServerConfig`. - pub fn build(self) -> AuthServerConfig { + pub fn build(self) -> AuthServerConfig { AuthServerConfig { socket_addr: self.socket_addr.unwrap_or_else(|| { SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), constants::DEFAULT_AUTH_PORT) @@ -182,6 +232,7 @@ impl AuthServerConfigBuilder { .set_id_provider(EthSubscriptionIdProvider::default()) }), ipc_endpoint: self.ipc_endpoint, + rpc_middleware: self.rpc_middleware, } } }