diff --git a/crates/rpc/rpc-builder/src/error.rs b/crates/rpc/rpc-builder/src/error.rs index 276d6ba62b..f83fdeed65 100644 --- a/crates/rpc/rpc-builder/src/error.rs +++ b/crates/rpc/rpc-builder/src/error.rs @@ -1,15 +1,18 @@ use std::net::SocketAddr; +use crate::RethRpcModule; use jsonrpsee::core::Error as JsonRpseeError; use std::{io, io::ErrorKind}; /// Rpc server kind. -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq, Copy, Clone)] pub enum ServerKind { /// Http. Http(SocketAddr), /// Websocket. WS(SocketAddr), + /// WS and http on the same port + WsHttp(SocketAddr), /// Auth. Auth(SocketAddr), } @@ -19,6 +22,7 @@ impl std::fmt::Display for ServerKind { match self { ServerKind::Http(addr) => write!(f, "{addr} (HTTP-RPC server)"), ServerKind::WS(addr) => write!(f, "{addr} (WS-RPC server)"), + ServerKind::WsHttp(addr) => write!(f, "{addr} (WS-HTTP-RPC server)"), ServerKind::Auth(addr) => write!(f, "{addr} (AUTH server)"), } } @@ -38,6 +42,9 @@ pub enum RpcError { /// IO error. error: io::Error, }, + /// Http and WS server configured on the same port but with conflicting settings. + #[error(transparent)] + WsHttpSamePortError(#[from] WsHttpSamePortError), /// Custom error. #[error("{0}")] Custom(String), @@ -62,3 +69,19 @@ impl RpcError { } } } + +/// Errors when trying to launch ws and http server on the same port. +#[derive(Debug, thiserror::Error)] +pub enum WsHttpSamePortError { + /// Ws and http server configured on same port but with different cors domains. + #[error("CORS domains for http and ws are different, but they are on the same port")] + ConflictingCorsDomains, + /// Ws and http server configured on same port but with different modules. + #[error("Different api modules for http and ws on the same port is currently not supported: http: {http_modules:?}, ws: {ws_modules:?}")] + ConflictingModules { + /// Http modules. + http_modules: Vec, + /// Ws modules. + ws_modules: Vec, + }, +} diff --git a/crates/rpc/rpc-builder/src/lib.rs b/crates/rpc/rpc-builder/src/lib.rs index 7a98c0b053..12240f8e02 100644 --- a/crates/rpc/rpc-builder/src/lib.rs +++ b/crates/rpc/rpc-builder/src/lib.rs @@ -116,7 +116,7 @@ use reth_tasks::TaskSpawner; use reth_transaction_pool::TransactionPool; use serde::{Deserialize, Serialize, Serializer}; use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, fmt, net::{Ipv4Addr, SocketAddr, SocketAddrV4}, str::FromStr, @@ -142,8 +142,8 @@ mod eth; pub mod constants; // re-export for convenience -use crate::auth::AuthRpcModule; pub use crate::eth::{EthConfig, EthHandlers}; +use crate::{auth::AuthRpcModule, error::WsHttpSamePortError}; pub use jsonrpsee::server::ServerBuilder; pub use reth_ipc::server::{Builder as IpcServerBuilder, Endpoint}; @@ -273,7 +273,7 @@ where let Self { client, pool, network, executor, events } = self; - let TransportRpcModuleConfig { http, ws, ipc, config } = module_config; + let TransportRpcModuleConfig { http, ws, ipc, config } = module_config.clone(); let mut registry = RethModuleRegistry::new( client, @@ -284,6 +284,7 @@ where config.unwrap_or_default(), ); + modules.config = module_config; modules.http = registry.maybe_module(http.as_ref()); modules.ws = registry.maybe_module(ws.as_ref()); modules.ipc = registry.maybe_module(ipc.as_ref()); @@ -303,7 +304,7 @@ where let Self { client, pool, network, executor, events } = self; if !module_config.is_empty() { - let TransportRpcModuleConfig { http, ws, ipc, config } = module_config; + let TransportRpcModuleConfig { http, ws, ipc, config } = module_config.clone(); let mut registry = RethModuleRegistry::new( client, @@ -314,6 +315,7 @@ where config.unwrap_or_default(), ); + modules.config = module_config; modules.http = registry.maybe_module(http.as_ref()); modules.ws = registry.maybe_module(ws.as_ref()); modules.ipc = registry.maybe_module(ipc.as_ref()); @@ -424,6 +426,14 @@ impl RpcModuleSelection { Ok(RpcModuleSelection::Selection(selection)) } + /// Returns true if no selection is configured + pub fn is_empty(&self) -> bool { + match self { + RpcModuleSelection::Selection(sel) => sel.is_empty(), + _ => false, + } + } + /// Creates a new [RpcModule] based on the configured reth modules. /// /// Note: This will always create new instance of the module handlers and is therefor only @@ -466,6 +476,21 @@ impl RpcModuleSelection { RpcModuleSelection::Standard => Self::STANDARD_MODULES.to_vec(), } } + + /// Returns true if both selections are identical. + fn are_identical(http: Option<&RpcModuleSelection>, ws: Option<&RpcModuleSelection>) -> bool { + match (http, ws) { + (Some(http), Some(ws)) => { + let http = http.clone().iter_selection().collect::>(); + let ws = ws.clone().iter_selection().collect::>(); + + http == ws + } + (Some(http), None) => http.is_empty(), + (None, Some(ws)) => ws.is_empty(), + _ => true, + } + } } impl From for RpcModuleSelection @@ -850,7 +875,7 @@ impl RpcServerConfig { /// To set a custom [IdProvider], please use [Self::with_id_provider]. pub fn with_http(mut self, config: ServerBuilder) -> Self { self.http_server_config = - Some(config.http_only().set_id_provider(EthSubscriptionIdProvider::default())); + Some(config.set_id_provider(EthSubscriptionIdProvider::default())); self } @@ -876,8 +901,7 @@ impl RpcServerConfig { /// Note: this always configures an [EthSubscriptionIdProvider] [IdProvider] for convenience. /// To set a custom [IdProvider], please use [Self::with_id_provider]. pub fn with_ws(mut self, config: ServerBuilder) -> Self { - self.ws_server_config = - Some(config.ws_only().set_id_provider(EthSubscriptionIdProvider::default())); + self.ws_server_config = Some(config.set_id_provider(EthSubscriptionIdProvider::default())); self } @@ -966,64 +990,100 @@ impl RpcServerConfig { self.build().await?.start(modules).await } - /// Finalize the configuration of the server(s). + /// Builds the ws and http server(s). /// - /// This consumes the builder and returns a server. - /// - /// Note: The server ist not started and does nothing unless polled, See also [RpcServer::start] - pub async fn build(self) -> Result { - let mut server = RpcServer::empty(); - + /// If both are on the same port, they are combined into one server. + async fn build_ws_http(&mut self) -> Result { let http_socket_addr = self.http_addr.unwrap_or(SocketAddr::V4(SocketAddrV4::new( Ipv4Addr::UNSPECIFIED, DEFAULT_HTTP_RPC_PORT, ))); - if let Some(builder) = self.http_server_config { - if let Some(cors) = self.http_cors_domains.as_deref().map(cors::create_cors_layer) { - let cors = cors.map_err(|err| RpcError::Custom(err.to_string()))?; - let middleware = tower::ServiceBuilder::new().layer(cors); - let http_server = - builder.set_middleware(middleware).build(http_socket_addr).await.map_err( - |err| { - RpcError::from_jsonrpsee_error(err, ServerKind::Http(http_socket_addr)) - }, - )?; - server.http_local_addr = http_server.local_addr().ok(); - server.http = Some(WsHttpServer::WithCors(http_server)); - } else { - let http_server = builder.build(http_socket_addr).await.map_err(|err| { - RpcError::from_jsonrpsee_error(err, ServerKind::Http(http_socket_addr)) - })?; - server.http_local_addr = http_server.local_addr().ok(); - server.http = Some(WsHttpServer::Plain(http_server)); - } - } - let ws_socket_addr = self.ws_addr.unwrap_or(SocketAddr::V4(SocketAddrV4::new( Ipv4Addr::UNSPECIFIED, DEFAULT_WS_RPC_PORT, ))); - if let Some(builder) = self.ws_server_config { - if let Some(cors) = self.ws_cors_domains.as_deref().map(cors::create_cors_layer) { - let cors = cors.map_err(|err| RpcError::Custom(err.to_string()))?; - let middleware = tower::ServiceBuilder::new().layer(cors); - let ws_server = - builder.set_middleware(middleware).build(ws_socket_addr).await.map_err( - |err| RpcError::from_jsonrpsee_error(err, ServerKind::WS(ws_socket_addr)), - )?; - server.http_local_addr = ws_server.local_addr().ok(); - server.ws = Some(WsHttpServer::WithCors(ws_server)); - } else { - let ws_server = builder.build(ws_socket_addr).await.map_err(|err| { - RpcError::from_jsonrpsee_error(err, ServerKind::WS(ws_socket_addr)) - })?; - server.ws_local_addr = ws_server.local_addr().ok(); - server.ws = Some(WsHttpServer::Plain(ws_server)); + // If both are configured on the same port, we combine them into one server. + if self.http_addr == self.ws_addr && + self.http_server_config.is_some() && + self.ws_server_config.is_some() + { + let cors = match (self.ws_cors_domains.as_ref(), self.http_cors_domains.as_ref()) { + (Some(_), Some(_)) => { + return Err(WsHttpSamePortError::ConflictingCorsDomains.into()) + } + (None, cors @ Some(_)) => cors, + (cors @ Some(_), None) => cors, + _ => None, } + .cloned(); + + // we merge this into one server using the http setup + self.ws_server_config.take(); + + let builder = self.http_server_config.take().expect("is set; qed"); + let (server, addr) = WsHttpServerKind::build( + builder, + http_socket_addr, + cors, + ServerKind::WsHttp(http_socket_addr), + ) + .await?; + return Ok(WsHttpServer { + http_local_addr: Some(addr), + ws_local_addr: Some(addr), + server: WsHttpServers::SamePort(server), + }) } + let mut http_local_addr = None; + let mut http_server = None; + + let mut ws_local_addr = None; + let mut ws_server = None; + if let Some(builder) = self.ws_server_config.take() { + let builder = builder.ws_only(); + let (server, addr) = WsHttpServerKind::build( + builder, + ws_socket_addr, + self.ws_cors_domains.take(), + ServerKind::WS(ws_socket_addr), + ) + .await?; + ws_local_addr = Some(addr); + ws_server = Some(server); + } + + if let Some(builder) = self.http_server_config.take() { + let builder = builder.http_only(); + let (server, addr) = WsHttpServerKind::build( + builder, + http_socket_addr, + self.http_cors_domains.take(), + ServerKind::Http(http_socket_addr), + ) + .await?; + http_local_addr = Some(addr); + http_server = Some(server); + } + + Ok(WsHttpServer { + http_local_addr, + ws_local_addr, + server: WsHttpServers::DifferentPort { http: http_server, ws: ws_server }, + }) + } + + /// Finalize the configuration of the server(s). + /// + /// This consumes the builder and returns a server. + /// + /// Note: The server ist not started and does nothing unless polled, See also [RpcServer::start] + pub async fn build(mut self) -> Result { + let mut server = RpcServer::empty(); + server.ws_http = self.build_ws_http().await?; + if let Some(builder) = self.ipc_server_config { let ipc_path = self .ipc_endpoint @@ -1120,11 +1180,27 @@ impl TransportRpcModuleConfig { pub fn ipc(&self) -> Option<&RpcModuleSelection> { self.ipc.as_ref() } + + /// Ensures that both http and ws are configured and that they are configured to use the same + /// port. + fn ensure_ws_http_identical(&self) -> Result<(), WsHttpSamePortError> { + if RpcModuleSelection::are_identical(self.http.as_ref(), self.ws.as_ref()) { + Ok(()) + } else { + let http_modules = + self.http.clone().map(RpcModuleSelection::into_selection).unwrap_or_default(); + let ws_modules = + self.ws.clone().map(RpcModuleSelection::into_selection).unwrap_or_default(); + Err(WsHttpSamePortError::ConflictingModules { http_modules, ws_modules }) + } + } } /// Holds installed modules per transport type. #[derive(Debug, Default)] pub struct TransportRpcModules { + /// The original config + config: TransportRpcModuleConfig, /// rpcs module for http http: Option>, /// rpcs module for ws @@ -1142,42 +1218,145 @@ impl TransportRpcModules<()> { } } -/// Container type for each transport ie. http, ws, and ipc server -pub struct RpcServer { +/// Container type for ws and http servers in all possible combinations. +#[derive(Default)] +struct WsHttpServer { /// The address of the http server http_local_addr: Option, /// The address of the ws server ws_local_addr: Option, - /// http server - http: Option, - /// ws server - ws: Option, - /// ipc server - ipc: Option, + /// Configured ws,http servers + server: WsHttpServers, } + +/// Enum for holding the http and ws servers in all possible combinations. +enum WsHttpServers { + /// Both servers are on the same port + SamePort(WsHttpServerKind), + /// Servers are on different ports + DifferentPort { http: Option, ws: Option }, +} + +// === impl WsHttpServers === + +impl WsHttpServers { + /// Starts the servers and returns the handles (http, ws) + async fn start( + self, + http_module: Option>, + ws_module: Option>, + config: &TransportRpcModuleConfig, + ) -> Result<(Option, Option), RpcError> { + let mut http_handle = None; + let mut ws_handle = None; + match self { + WsHttpServers::SamePort(both) => { + // Make sure http and ws modules are identical, since we currently can't run + // different modules on same server + config.ensure_ws_http_identical()?; + + if let Some(module) = http_module.or(ws_module) { + let handle = both.start(module).await?; + http_handle = Some(handle.clone()); + ws_handle = Some(handle); + } + } + WsHttpServers::DifferentPort { http, ws } => { + if let Some((server, module)) = + http.and_then(|server| http_module.map(|module| (server, module))) + { + http_handle = Some(server.start(module).await?); + } + if let Some((server, module)) = + ws.and_then(|server| ws_module.map(|module| (server, module))) + { + ws_handle = Some(server.start(module).await?); + } + } + } + + Ok((http_handle, ws_handle)) + } +} + +impl Default for WsHttpServers { + fn default() -> Self { + Self::DifferentPort { http: None, ws: None } + } +} + /// Http Servers Enum -pub enum WsHttpServer { +enum WsHttpServerKind { /// Http server Plain(Server), /// Http server with cors WithCors(Server>), } +// === impl WsHttpServerKind === + +impl WsHttpServerKind { + /// Starts the server and returns the handle + async fn start(self, module: RpcModule<()>) -> Result { + match self { + WsHttpServerKind::Plain(server) => Ok(server.start(module)?), + WsHttpServerKind::WithCors(server) => Ok(server.start(module)?), + } + } + + /// Builds + async fn build( + builder: ServerBuilder, + socket_addr: SocketAddr, + cors_domains: Option, + server_kind: ServerKind, + ) -> Result<(Self, SocketAddr), RpcError> { + if let Some(cors) = cors_domains.as_deref().map(cors::create_cors_layer) { + let cors = cors.map_err(|err| RpcError::Custom(err.to_string()))?; + let middleware = tower::ServiceBuilder::new().layer(cors); + let server = builder + .set_middleware(middleware) + .build(socket_addr) + .await + .map_err(|err| RpcError::from_jsonrpsee_error(err, server_kind))?; + let local_addr = server.local_addr()?; + let server = WsHttpServerKind::WithCors(server); + Ok((server, local_addr)) + } else { + let server = builder + .build(socket_addr) + .await + .map_err(|err| RpcError::from_jsonrpsee_error(err, server_kind))?; + let local_addr = server.local_addr()?; + let server = WsHttpServerKind::Plain(server); + Ok((server, local_addr)) + } + } +} + +/// Container type for each transport ie. http, ws, and ipc server +pub struct RpcServer { + /// Configured ws,http servers + ws_http: WsHttpServer, + /// ipc server + ipc: Option, +} + // === impl RpcServer === impl RpcServer { fn empty() -> RpcServer { - RpcServer { http_local_addr: None, ws_local_addr: None, http: None, ws: None, ipc: None } + RpcServer { ws_http: Default::default(), ipc: None } } /// Returns the [`SocketAddr`] of the http server if started. pub fn http_local_addr(&self) -> Option { - self.http_local_addr + self.ws_http.http_local_addr } /// Returns the [`SocketAddr`] of the ws server if started. pub fn ws_local_addr(&self) -> Option { - self.ws_local_addr + self.ws_http.ws_local_addr } /// Returns the [`Endpoint`] of the ipc server if started. @@ -1189,49 +1368,28 @@ impl RpcServer { /// /// This returns an [RpcServerHandle] that's connected to the server task(s) until the server is /// stopped or the [RpcServerHandle] is dropped. - #[instrument(name = "start", skip_all, fields(http = ?self.http_local_addr, ws = ?self.ws_local_addr, ipc = ?self.ipc_endpoint().map(|ipc|ipc.path())), target = "rpc", level = "TRACE")] + #[instrument(name = "start", skip_all, fields(http = ?self.http_local_addr(), ws = ?self.ws_local_addr(), ipc = ?self.ipc_endpoint().map(|ipc|ipc.path())), target = "rpc", level = "TRACE")] pub async fn start( self, modules: TransportRpcModules<()>, ) -> Result { trace!(target: "rpc", "staring RPC server"); - let TransportRpcModules { http, ws, ipc } = modules; + let Self { ws_http, ipc: ipc_server } = self; + let TransportRpcModules { config, http, ws, ipc } = modules; let mut handle = RpcServerHandle { - http_local_addr: self.http_local_addr, - ws_local_addr: self.ws_local_addr, + http_local_addr: ws_http.http_local_addr, + ws_local_addr: ws_http.ws_local_addr, http: None, ws: None, ipc: None, }; - // Start all servers - if let Some((server, module)) = - self.http.and_then(|server| http.map(|module| (server, module))) - { - match server { - WsHttpServer::Plain(server) => { - handle.http = Some(server.start(module)?); - } - WsHttpServer::WithCors(server) => { - handle.http = Some(server.start(module)?); - } - } - } - - if let Some((server, module)) = self.ws.and_then(|server| ws.map(|module| (server, module))) - { - match server { - WsHttpServer::Plain(server) => { - handle.ws = Some(server.start(module)?); - } - WsHttpServer::WithCors(server) => { - handle.ws = Some(server.start(module)?); - } - } - } + let (http, ws) = ws_http.server.start(http, ws, &config).await?; + handle.http = http; + handle.ws = ws; if let Some((server, module)) = - self.ipc.and_then(|server| ipc.map(|module| (server, module))) + ipc_server.and_then(|server| ipc.map(|module| (server, module))) { handle.ipc = Some(server.start(module).await?); } @@ -1243,8 +1401,8 @@ impl RpcServer { impl fmt::Debug for RpcServer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("RpcServer") - .field("http", &self.http.is_some()) - .field("ws", &self.ws.is_some()) + .field("http", &self.ws_http.http_local_addr.is_some()) + .field("ws", &self.ws_http.http_local_addr.is_some()) .field("ipc", &self.ipc.is_some()) .finish() } @@ -1324,8 +1482,8 @@ impl RpcServerHandle { } } -impl std::fmt::Debug for RpcServerHandle { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Debug for RpcServerHandle { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("RpcServerHandle") .field("http", &self.http.is_some()) .field("ws", &self.ws.is_some()) @@ -1338,6 +1496,22 @@ impl std::fmt::Debug for RpcServerHandle { mod tests { use super::*; + #[test] + fn identical_selection() { + assert!(RpcModuleSelection::are_identical( + Some(&RpcModuleSelection::All), + Some(&RpcModuleSelection::All), + )); + assert!(!RpcModuleSelection::are_identical( + Some(&RpcModuleSelection::All), + Some(&RpcModuleSelection::Standard), + )); + assert!(RpcModuleSelection::are_identical( + Some(&RpcModuleSelection::Selection(RpcModuleSelection::Standard.into_selection())), + Some(&RpcModuleSelection::Standard), + )); + } + #[test] fn test_rpc_module_str() { macro_rules! assert_rpc_module { diff --git a/crates/rpc/rpc-builder/tests/it/startup.rs b/crates/rpc/rpc-builder/tests/it/startup.rs index 6905e8da64..6ca0ad0138 100644 --- a/crates/rpc/rpc-builder/tests/it/startup.rs +++ b/crates/rpc/rpc-builder/tests/it/startup.rs @@ -1,15 +1,17 @@ //! Startup tests -use crate::utils::{launch_http, launch_ws, test_rpc_builder}; +use crate::utils::{ + launch_http, launch_http_ws_same_port, launch_ws, test_address, test_rpc_builder, +}; use reth_rpc_builder::{ - error::{RpcError, ServerKind}, + error::{RpcError, ServerKind, WsHttpSamePortError}, RethRpcModule, RpcServerConfig, TransportRpcModuleConfig, }; use std::io; -fn is_addr_in_use_kind(err: RpcError, kind: ServerKind) -> bool { +fn is_addr_in_use_kind(err: &RpcError, kind: ServerKind) -> bool { match err { RpcError::AddressAlreadyInUse { kind: k, error } => { - k == kind && error.kind() == io::ErrorKind::AddrInUse + *k == kind && error.kind() == io::ErrorKind::AddrInUse } _ => false, } @@ -24,7 +26,8 @@ async fn test_http_addr_in_use() { let result = server .start_server(RpcServerConfig::http(Default::default()).with_http_address(addr)) .await; - assert!(is_addr_in_use_kind(result.unwrap_err(), ServerKind::Http(addr))); + let err = result.unwrap_err(); + assert!(is_addr_in_use_kind(&err, ServerKind::Http(addr)), "{err:?}"); } #[tokio::test(flavor = "multi_thread")] @@ -35,5 +38,37 @@ async fn test_ws_addr_in_use() { let server = builder.build(TransportRpcModuleConfig::set_ws(vec![RethRpcModule::Admin])); let result = server.start_server(RpcServerConfig::ws(Default::default()).with_ws_address(addr)).await; - assert!(is_addr_in_use_kind(result.unwrap_err(), ServerKind::WS(addr))); + let err = result.unwrap_err(); + assert!(is_addr_in_use_kind(&err, ServerKind::WS(addr)), "{err:?}"); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_launch_same_port() { + let handle = launch_http_ws_same_port(vec![RethRpcModule::Admin]).await; + let ws_addr = handle.ws_local_addr().unwrap(); + let http_addr = handle.http_local_addr().unwrap(); + assert_eq!(ws_addr, http_addr); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_launch_same_port_different_modules() { + let builder = test_rpc_builder(); + let server = builder.build( + TransportRpcModuleConfig::set_ws(vec![RethRpcModule::Admin]) + .with_http(vec![RethRpcModule::Eth]), + ); + let addr = test_address(); + let res = server + .start_server( + RpcServerConfig::ws(Default::default()) + .with_ws_address(addr) + .with_http(Default::default()) + .with_http_address(addr), + ) + .await; + let err = res.unwrap_err(); + assert!(matches!( + err, + RpcError::WsHttpSamePortError(WsHttpSamePortError::ConflictingModules { .. }) + )); } diff --git a/crates/rpc/rpc-builder/tests/it/utils.rs b/crates/rpc/rpc-builder/tests/it/utils.rs index 52b0610ed7..867c4e4b7c 100644 --- a/crates/rpc/rpc-builder/tests/it/utils.rs +++ b/crates/rpc/rpc-builder/tests/it/utils.rs @@ -75,6 +75,24 @@ pub async fn launch_http_ws(modules: impl Into) -> RpcServer .unwrap() } +/// Launches a new server with http and ws and with the given modules on the same port. +pub async fn launch_http_ws_same_port(modules: impl Into) -> RpcServerHandle { + let builder = test_rpc_builder(); + let modules = modules.into(); + let server = + builder.build(TransportRpcModuleConfig::set_ws(modules.clone()).with_http(modules)); + let addr = test_address(); + server + .start_server( + RpcServerConfig::ws(Default::default()) + .with_ws_address(addr) + .with_http(Default::default()) + .with_http_address(addr), + ) + .await + .unwrap() +} + /// Returns an [RpcModuleBuilder] with testing components. pub fn test_rpc_builder() -> RpcModuleBuilder< NoopProvider,