From c4531feea50c7d7533604ccc98ffa5d0dc0176a1 Mon Sep 17 00:00:00 2001 From: Matthias Seitz Date: Thu, 23 Mar 2023 13:48:37 +0100 Subject: [PATCH] feat(rpc): add ws allowed origins (#1924) --- bin/reth/src/args/rpc_server_args.rs | 7 +++- crates/rpc/rpc-builder/src/lib.rs | 58 +++++++++++++++++++++------- 2 files changed, 49 insertions(+), 16 deletions(-) diff --git a/bin/reth/src/args/rpc_server_args.rs b/bin/reth/src/args/rpc_server_args.rs index 5c593b84e6..475873d4fa 100644 --- a/bin/reth/src/args/rpc_server_args.rs +++ b/bin/reth/src/args/rpc_server_args.rs @@ -54,6 +54,10 @@ pub struct RpcServerArgs { #[arg(long = "ws.port")] pub ws_port: Option, + /// Origins from which to accept WebSocket requests + #[arg(long = "ws.origins", name = "ws.origins")] + pub ws_allowed_origins: Option, + /// Rpc Modules to be configured for Ws server #[arg(long = "ws.api")] pub ws_api: Option, @@ -200,7 +204,8 @@ impl RpcServerArgs { config = config .with_http_address(socket_address) .with_http(ServerBuilder::new()) - .with_cors(self.http_corsdomain.clone().unwrap_or_default()); + .with_http_cors(self.http_corsdomain.clone()) + .with_ws_cors(self.ws_allowed_origins.clone()); } if self.ws { diff --git a/crates/rpc/rpc-builder/src/lib.rs b/crates/rpc/rpc-builder/src/lib.rs index c87dcf3bfe..973a6f46c0 100644 --- a/crates/rpc/rpc-builder/src/lib.rs +++ b/crates/rpc/rpc-builder/src/lib.rs @@ -619,12 +619,14 @@ where pub struct RpcServerConfig { /// Configs for JSON-RPC Http. http_server_config: Option, - /// Cors Domains + /// Allowed CORS Domains for http http_cors_domains: Option, /// Address where to bind the http server to http_addr: Option, /// Configs for WS server ws_server_config: Option, + /// Allowed CORS Domains for ws. + ws_cors_domains: Option, /// Address where to bind the ws server to ws_addr: Option, /// Configs for JSON-RPC IPC server @@ -661,9 +663,20 @@ impl RpcServerConfig { self } - /// Configure the corsdomains - pub fn with_cors(mut self, cors_domain: String) -> Self { - self.http_cors_domains = Some(cors_domain); + /// Configure the cors domains for http _and_ ws + pub fn with_cors(self, cors_domain: Option) -> Self { + self.with_http_cors(cors_domain.clone()).with_ws_cors(cors_domain) + } + + /// Configure the cors domains for HTTP + pub fn with_http_cors(mut self, cors_domain: Option) -> Self { + self.http_cors_domains = cors_domain; + self + } + + /// Configure the cors domains for WS + pub fn with_ws_cors(mut self, cors_domain: Option) -> Self { + self.ws_cors_domains = cors_domain; self } @@ -773,11 +786,11 @@ impl RpcServerConfig { let http_server = builder.set_middleware(middleware).build(http_socket_addr).await?; server.http_local_addr = http_server.local_addr().ok(); - server.http = Some(HttpServer::WithCors(http_server)); + server.http = Some(WsHttpServer::WithCors(http_server)); } else { let http_server = builder.build(http_socket_addr).await?; server.http_local_addr = http_server.local_addr().ok(); - server.http = Some(HttpServer::Plain(http_server)); + server.http = Some(WsHttpServer::Plain(http_server)); } } @@ -787,9 +800,17 @@ impl RpcServerConfig { ))); if let Some(builder) = self.ws_server_config { - let ws_server = builder.build(ws_socket_addr).await.unwrap(); - server.ws_local_addr = ws_server.local_addr().ok(); - server.ws = Some(ws_server); + if let Some(cors) = self.ws_cors_domains.as_deref().map(cors::create_cors_layer) { + let cors = cors.map_err(|err| RpcError::Custom(err.to_string()))?; + let middleware = tower::ServiceBuilder::new().layer(cors); + let ws_server = builder.set_middleware(middleware).build(ws_socket_addr).await?; + server.http_local_addr = ws_server.local_addr().ok(); + server.ws = Some(WsHttpServer::WithCors(ws_server)); + } else { + let ws_server = builder.build(ws_socket_addr).await?; + server.ws_local_addr = ws_server.local_addr().ok(); + server.ws = Some(WsHttpServer::Plain(ws_server)); + } } if let Some(builder) = self.ipc_server_config { @@ -917,14 +938,14 @@ pub struct RpcServer { /// The address of the ws server ws_local_addr: Option, /// http server - http: Option, + http: Option, /// ws server - ws: Option, + ws: Option, /// ipc server ipc: Option, } /// Http Servers Enum -pub enum HttpServer { +pub enum WsHttpServer { /// Http server Plain(Server), /// Http server with cors @@ -970,10 +991,10 @@ impl RpcServer { self.http.and_then(|server| http.map(|module| (server, module))) { match server { - HttpServer::Plain(server) => { + WsHttpServer::Plain(server) => { handle.http = Some(server.start(module)?); } - HttpServer::WithCors(server) => { + WsHttpServer::WithCors(server) => { handle.http = Some(server.start(module)?); } } @@ -981,7 +1002,14 @@ impl RpcServer { if let Some((server, module)) = self.ws.and_then(|server| ws.map(|module| (server, module))) { - handle.ws = Some(server.start(module)?); + match server { + WsHttpServer::Plain(server) => { + handle.ws = Some(server.start(module)?); + } + WsHttpServer::WithCors(server) => { + handle.ws = Some(server.start(module)?); + } + } } if let Some((server, module)) =