diff --git a/Cargo.lock b/Cargo.lock
index de0f8faa0c..fc33d1c5a4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4969,6 +4969,7 @@ dependencies = [
  "serde_json",
  "thiserror",
  "tokio",
+ "tokio-stream",
  "tokio-util 0.7.7",
  "tower",
  "tracing",
diff --git a/crates/rpc/ipc/Cargo.toml b/crates/rpc/ipc/Cargo.toml
index b1d6ee9851..bf270ef655 100644
--- a/crates/rpc/ipc/Cargo.toml
+++ b/crates/rpc/ipc/Cargo.toml
@@ -16,6 +16,7 @@ futures = "0.3"
 parity-tokio-ipc = "0.9.0"
 tokio = { version = "1", features = ["net", "time", "rt-multi-thread"] }
 tokio-util = { version = "0.7", features = ["codec"] }
+tokio-stream = "0.1"
 async-trait = "0.1"
 pin-project = "1.0"
 tower = "0.4"
@@ -29,6 +30,7 @@ thiserror = "1.0.37"
 
 [dev-dependencies]
 tracing-test = "0.2"
+tokio-stream = { version = "0.1", features = ["sync"] }
 
 [features]
 client = ["jsonrpsee/client", "jsonrpsee/async-client"]
diff --git a/crates/rpc/ipc/src/server/ipc.rs b/crates/rpc/ipc/src/server/ipc.rs
index 81d5c3226a..c03f3ca385 100644
--- a/crates/rpc/ipc/src/server/ipc.rs
+++ b/crates/rpc/ipc/src/server/ipc.rs
@@ -10,9 +10,14 @@ use jsonrpsee::{
     server::{
         logger,
         logger::{Logger, TransportProtocol},
+        IdProvider,
     },
-    types::{error::ErrorCode, ErrorObject, Id, InvalidRequest, Notification, Params, Request},
-    MethodCallback, Methods,
+    types::{
+        error::{reject_too_many_subscriptions, ErrorCode},
+        ErrorObject, Id, InvalidRequest, Notification, Params, Request,
+    },
+    BoundedSubscriptions, CallOrSubscription, MethodCallback, MethodSink, Methods,
+    SubscriptionState,
 };
 use std::sync::Arc;
 use tokio::sync::OwnedSemaphorePermit;
@@ -32,16 +37,19 @@ pub(crate) struct CallData<'a, L: Logger> {
     conn_id: usize,
     logger: &'a L,
     methods: &'a Methods,
+    id_provider: &'a dyn IdProvider,
+    sink: &'a MethodSink,
     max_response_body_size: u32,
     max_log_length: u32,
     request_start: L::Instant,
+    bounded_subscriptions: BoundedSubscriptions,
 }
 
 // Batch responses must be sent back as a single message so we read the results from each
 // request in the batch and read the results off of a new channel, `rx_batch`, and then send the
 // complete batch response back to the client over `tx`.
 #[instrument(name = "batch", skip(b), level = "TRACE")]
-pub(crate) async fn process_batch_request<L>(b: Batch<'_, L>) -> String
+pub(crate) async fn process_batch_request<L>(b: Batch<'_, L>) -> Option<String>
 where
     L: Logger,
 {
@@ -56,7 +64,9 @@ where
             .into_iter()
             .filter_map(|v| {
                 if let Ok(req) = serde_json::from_str::<Request<'_>>(v.get()) {
-                    Some(Either::Right(execute_call(req, call.clone())))
+                    Some(Either::Right(async {
+                        execute_call(req, call.clone()).await.into_response()
+                    }))
                 } else if let Ok(_notif) = serde_json::from_str::<Notif<'_>>(v.get()) {
                     // notifications should not be answered.
                    got_notif = true;
@@ -77,31 +87,31 @@ where
 
     while let Some(response) = pending_calls.next().await {
         if let Err(too_large) = batch_response.append(&response) {
-            return too_large
+            return Some(too_large)
         }
     }
 
     if got_notif && batch_response.is_empty() {
-        String::new()
+        None
     } else {
-        batch_response.finish()
+        Some(batch_response.finish())
     }
 } else {
-    batch_response_error(Id::Null, ErrorObject::from(ErrorCode::ParseError))
+    Some(batch_response_error(Id::Null, ErrorObject::from(ErrorCode::ParseError)))
 }
 }
 
 pub(crate) async fn process_single_request<L: Logger>(
     data: Vec<u8>,
     call: CallData<'_, L>,
-) -> MethodResponse {
+) -> Option<CallOrSubscription> {
     if let Ok(req) = serde_json::from_slice::<Request<'_>>(&data) {
-        execute_call_with_tracing(req, call).await
-    } else if let Ok(notif) = serde_json::from_slice::<Notif<'_>>(&data) {
-        execute_notification(notif, call.max_log_length)
+        Some(execute_call_with_tracing(req, call).await)
+    } else if serde_json::from_slice::<Notif<'_>>(&data).is_ok() {
+        None
     } else {
         let (id, code) = prepare_error(&data);
-        MethodResponse::error(id, ErrorObject::from(code))
+        Some(CallOrSubscription::Call(MethodResponse::error(id, ErrorObject::from(code))))
     }
 }
@@ -109,21 +119,24 @@ pub(crate) async fn process_single_request<L: Logger>(
 pub(crate) async fn execute_call_with_tracing<'a, L: Logger>(
     req: Request<'a>,
     call: CallData<'_, L>,
-) -> MethodResponse {
+) -> CallOrSubscription {
     execute_call(req, call).await
 }
 
 pub(crate) async fn execute_call<L: Logger>(
     req: Request<'_>,
     call: CallData<'_, L>,
-) -> MethodResponse {
+) -> CallOrSubscription {
     let CallData {
         methods,
-        logger,
         max_response_body_size,
         max_log_length,
         conn_id,
+        id_provider,
+        sink,
+        logger,
         request_start,
+        bounded_subscriptions,
     } = call;
 
     rx_log_from_json(&req, call.max_log_length);
@@ -140,7 +153,8 @@ pub(crate) async fn execute_call<L: Logger>(
                 logger::MethodKind::Unknown,
                 TransportProtocol::Http,
             );
-            MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound))
+            let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound));
+            CallOrSubscription::Call(response)
         }
         Some((name, method)) => match method {
             MethodCallback::Sync(callback) => {
                 logger.on_call(
                     name,
                     params.clone(),
                     logger::MethodKind::MethodCall,
                     TransportProtocol::Http,
                 );
-                (callback)(id, params, max_response_body_size as usize)
+                let response = (callback)(id, params, max_response_body_size as usize);
+                CallOrSubscription::Call(response)
             }
             MethodCallback::Async(callback) => {
                 logger.on_call(
@@ -161,23 +176,50 @@ pub(crate) async fn execute_call<L: Logger>(
                 );
                 let id = id.into_owned();
                 let params = params.into_owned();
-                (callback)(id, params, conn_id, max_response_body_size as usize).await
+                let response =
+                    (callback)(id, params, conn_id, max_response_body_size as usize).await;
+                CallOrSubscription::Call(response)
             }
-            MethodCallback::Subscription(_) | MethodCallback::Unsubscription(_) => {
+            MethodCallback::Subscription(callback) => {
+                if let Some(p) = bounded_subscriptions.acquire() {
+                    let conn_state =
+                        SubscriptionState { conn_id, id_provider, subscription_permit: p };
+                    match callback(id, params, sink.clone(), conn_state).await {
+                        Ok(r) => CallOrSubscription::Subscription(r),
+                        Err(id) => {
+                            let response = MethodResponse::error(
+                                id,
+                                ErrorObject::from(ErrorCode::InternalError),
+                            );
+                            CallOrSubscription::Call(response)
+                        }
+                    }
+                } else {
+                    let response = MethodResponse::error(
+                        id,
+                        reject_too_many_subscriptions(bounded_subscriptions.max()),
+                    );
+                    CallOrSubscription::Call(response)
+                }
+            }
+            MethodCallback::Unsubscription(callback) => {
                 logger.on_call(
                     name,
                    params.clone(),
-                    logger::MethodKind::Unknown,
-                    TransportProtocol::Http,
+                    logger::MethodKind::Unsubscription,
+                    TransportProtocol::WebSocket,
                 );
-                tracing::error!("Subscriptions not supported on HTTP");
-                MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError))
+
+                // Don't adhere to any resource or subscription limits; always let unsubscribing
+                // happen!
+                let result = callback(id, params, conn_id, max_response_body_size as usize);
+                CallOrSubscription::Call(result)
             }
         },
     };
 
-    tx_log_from_str(&response.result, max_log_length);
-    logger.on_result(name, response.success, request_start, TransportProtocol::Http);
+    tx_log_from_str(&response.as_response().result, max_log_length);
+    logger.on_result(name, response.as_response().success, request_start, TransportProtocol::Http);
 
     response
 }
@@ -198,10 +240,26 @@ pub(crate) struct HandleRequest<L: Logger> {
     pub(crate) batch_requests_supported: bool,
     pub(crate) logger: L,
     pub(crate) conn: Arc<OwnedSemaphorePermit>,
+    pub(crate) bounded_subscriptions: BoundedSubscriptions,
+    pub(crate) method_sink: MethodSink,
+    pub(crate) id_provider: Arc<dyn IdProvider>,
 }
 
-pub(crate) async fn handle_request<L: Logger>(request: String, input: HandleRequest<L>) -> String {
-    let HandleRequest { methods, max_response_body_size, max_log_length, logger, conn, .. } = input;
+pub(crate) async fn handle_request<L: Logger>(
+    request: String,
+    input: HandleRequest<L>,
+) -> Option<String> {
+    let HandleRequest {
+        methods,
+        max_response_body_size,
+        max_log_length,
+        logger,
+        conn,
+        bounded_subscriptions,
+        method_sink,
+        id_provider,
+        ..
+    } = input;
 
     enum Kind {
         Single,
@@ -223,14 +281,25 @@ pub(crate) async fn handle_request<L: Logger>(request: String, input: HandleRequ
         conn_id: 0,
         logger: &logger,
         methods: &methods,
+        id_provider: &*id_provider,
+        sink: &method_sink,
         max_response_body_size,
         max_log_length,
         request_start,
+        bounded_subscriptions,
     };
 
     // Single request or notification
     let res = if matches!(request_kind, Kind::Single) {
         let response = process_single_request(request.into_bytes(), call).await;
-        response.result
+        match response {
+            Some(CallOrSubscription::Call(response)) => Some(response.result),
+            Some(CallOrSubscription::Subscription(_)) => {
+                // Subscription responses are sent directly over the sink; returning a response
+                // here would lead to a duplicate response for the subscription.
+                None
+            }
+            None => None,
+        }
     } else {
         process_batch_request(Batch { data: request.into_bytes(), call }).await
    };
diff --git a/crates/rpc/ipc/src/server/mod.rs b/crates/rpc/ipc/src/server/mod.rs
index b7b0ccf4b7..6de31d61da 100644
--- a/crates/rpc/ipc/src/server/mod.rs
+++ b/crates/rpc/ipc/src/server/mod.rs
@@ -8,7 +8,7 @@ use futures::{FutureExt, SinkExt, Stream, StreamExt};
 use jsonrpsee::{
     core::{Error, TEN_MB_SIZE_BYTES},
     server::{logger::Logger, IdProvider, RandomIntegerIdProvider, ServerHandle},
-    Methods,
+    BoundedSubscriptions, MethodSink, Methods,
 };
 use std::{
     future::Future,
@@ -26,6 +26,8 @@ use tracing::{trace, warn};
 
 // re-export so can be used during builder setup
 pub use parity_tokio_ipc::Endpoint;
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::ReceiverStream;
 
 mod connection;
 mod future;
@@ -104,6 +106,7 @@ impl IpcServer {
             }
         }
 
+        let message_buffer_capacity = self.cfg.message_buffer_capacity;
         let max_request_body_size = self.cfg.max_request_body_size;
         let max_response_body_size = self.cfg.max_response_body_size;
         let max_log_length = self.cfg.max_log_length;
@@ -141,6 +144,9 @@ impl IpcServer {
                 }
             };
 
+            let (tx, rx) = mpsc::channel::<String>(message_buffer_capacity as usize);
+            let method_sink =
+                MethodSink::new_with_limit(tx, max_response_body_size, max_log_length);
            let tower_service = TowerService {
                 inner: ServiceData {
                     methods: methods.clone(),
@@ -153,11 +159,20 @@ impl IpcServer {
                     conn_id: id,
                     logger,
                     conn: Arc::new(conn),
+                    bounded_subscriptions: BoundedSubscriptions::new(
+                        max_subscriptions_per_connection,
+                    ),
+                    method_sink,
                 },
             };
 
             let service = self.service_builder.service(tower_service);
-            connections.add(Box::pin(spawn_connection(ipc, service, stop_handle.clone())));
+            connections.add(Box::pin(spawn_connection(
+                ipc,
+                service,
+                stop_handle.clone(),
+                rx,
+            )));
 
             id = id.wrapping_add(1);
         }
@@ -183,7 +198,7 @@ impl std::fmt::Debug for IpcServer {
     }
 }
 
-/// Data required by the server to handle requests.
+/// Data required by the server to handle requests received via an IPC connection.
 #[derive(Debug, Clone)]
 #[allow(unused)]
 pub(crate) struct ServiceData<L: Logger> {
@@ -209,6 +224,12 @@ pub(crate) struct ServiceData<L: Logger> {
     pub(crate) logger: L,
     /// Handle to hold a `connection permit`.
     pub(crate) conn: Arc<OwnedSemaphorePermit>,
+    /// Limits the number of subscriptions for this connection.
+    pub(crate) bounded_subscriptions: BoundedSubscriptions,
+    /// Sink that is used to send back responses to the connection.
+    ///
+    /// This is used for subscriptions.
+    pub(crate) method_sink: MethodSink,
 }
 
 /// JsonRPSee service compatible with `tower`.
@@ -221,7 +242,12 @@ pub struct TowerService<L: Logger> {
 }
 
 impl<L: Logger> Service<String> for TowerService<L> {
-    type Response = String;
+    /// The response of a handled RPC call.
+    ///
+    /// This is an `Option` because subscriptions and call responses are handled differently.
+    /// This will be `Some` for calls, and `None` for subscriptions, because the subscription
+    /// response will be emitted via the `method_sink`.
+    type Response = Option<String>;
 
     type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
 
@@ -244,31 +270,45 @@ impl<L: Logger> Service<String> for TowerService<L> {
             batch_requests_supported: true,
             logger: self.inner.logger.clone(),
             conn: self.inner.conn.clone(),
+            bounded_subscriptions: self.inner.bounded_subscriptions.clone(),
+            method_sink: self.inner.method_sink.clone(),
+            id_provider: self.inner.id_provider.clone(),
         };
 
         Box::pin(ipc::handle_request(request, data).map(Ok))
     }
 }
 
-/// Spawns the connection in a new task
+/// Spawns the IPC connection onto a new task.
 async fn spawn_connection<S, T>(
     conn: IpcConn<JsonRpcStream<T>>,
     mut service: S,
     mut stop_handle: StopHandle,
+    rx: mpsc::Receiver<String>,
 ) where
-    S: Service<String, Response = String> + Send + 'static,
+    S: Service<String, Response = Option<String>> + Send + 'static,
     S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
     S::Future: Send,
     T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
 {
     let task = tokio::task::spawn(async move {
-        tokio::pin!(conn);
+        let rx_item = ReceiverStream::new(rx);
+        tokio::pin!(conn, rx_item);
 
         loop {
-            let request = tokio::select! {
+            let item = tokio::select! {
                res = conn.next() => {
                    match res {
                        Some(Ok(request)) => {
-                            request
+                            // handle the RPC request
+                            match service.call(request).await {
+                                Ok(Some(resp)) => {
+                                    resp
+                                },
+                                Ok(None) => {
+                                    continue
+                                },
+                                Err(err) => err.into().to_string(),
+                            }
                        },
                        Some(Err(e)) => {
                            tracing::warn!("Request failed: {:?}", e);
@@ -279,19 +319,21 @@ async fn spawn_connection<S, T>(
                    }
                }
            }
+            item = rx_item.next() => {
+                match item {
+                    Some(item) => item,
+                    None => {
+                        continue
+                    }
+                }
+            }
             _ = stop_handle.shutdown() => {
                 break
             }
         };
 
-        // handle the RPC request
-        let resp = match service.call(request).await {
-            Ok(resp) => resp,
-            Err(err) => err.into().to_string(),
-        };
-
-        // send back
-        if let Err(err) = conn.send(resp).await {
+        // send the item over the IPC connection
+        if let Err(err) = conn.send(item).await {
             warn!("Failed to send response: {:?}", err);
             break
         }
@@ -352,6 +394,8 @@ pub struct Settings {
     max_connections: u32,
     /// Maximum number of subscriptions per connection.
     max_subscriptions_per_connection: u32,
+    /// Number of messages the server is allowed to `buffer` until backpressure kicks in.
+    message_buffer_capacity: u32,
     /// Custom tokio runtime to run the server on.
     tokio_runtime: Option<tokio::runtime::Handle>,
 }
@@ -364,6 +408,7 @@ impl Default for Settings {
             max_log_length: 4096,
             max_connections: 100,
             max_subscriptions_per_connection: 1024,
+            message_buffer_capacity: 1024,
             tokio_runtime: None,
         }
     }
@@ -421,6 +466,28 @@ impl Builder {
         self
     }
 
+    /// The server enforces backpressure, which means that up to `n`
+    /// messages can be buffered while the client can't keep up with
+    /// the server.
+    ///
+    /// This `capacity` is applied per connection and covers all
+    /// JSON-RPC messages on that connection.
+    ///
+    /// For example, if a subscription produces plenty of new items
+    /// and the client can't keep up, then no new messages are handled.
+    ///
+    /// If this limit is exceeded, the server will "back off" and only
+    /// accept new messages once the client reads pending messages.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the buffer capacity is 0.
+    pub fn set_message_buffer_capacity(mut self, c: u32) -> Self {
+        self.settings.message_buffer_capacity = c;
+        self
+    }
+
     /// Add a logger to the builder [`Logger`].
     pub fn set_logger<T: Logger>(self, logger: T) -> Builder<T> {
         Builder {
@@ -514,10 +581,61 @@ impl Builder {
 mod tests {
     use super::*;
     use crate::client::IpcClientBuilder;
-    use jsonrpsee::{core::client::ClientT, rpc_params, RpcModule};
+    use futures::future::{select, Either};
+    use jsonrpsee::{
+        core::client::{ClientT, Subscription, SubscriptionClientT},
+        rpc_params, PendingSubscriptionSink, RpcModule, SubscriptionMessage,
+    };
     use parity_tokio_ipc::dummy_endpoint;
+    use tokio::sync::broadcast;
+    use tokio_stream::wrappers::BroadcastStream;
     use tracing_test::traced_test;
 
+    async fn pipe_from_stream_with_bounded_buffer(
+        pending: PendingSubscriptionSink,
+        stream: BroadcastStream<usize>,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        let sink = pending.accept().await.unwrap();
+        let closed = sink.closed();
+
+        futures::pin_mut!(closed, stream);
+
+        loop {
+            match select(closed, stream.next()).await {
+                // subscription closed.
+                Either::Left((_, _)) => break Ok(()),
+
+                // received new item from the stream.
+                Either::Right((Some(Ok(item)), c)) => {
+                    let notif = SubscriptionMessage::from_json(&item)?;
+
+                    // NOTE: this will block until there is a spot in the queue,
+                    // and you might want to do something smarter if it's
+                    // critical that "the most recent item" must be sent when it is produced.
+                    if sink.send(notif).await.is_err() {
+                        break Ok(())
+                    }
+
+                    closed = c;
+                }
+
+                // Send back the error.
+                Either::Right((Some(Err(e)), _)) => break Err(e.into()),
+
+                // Stream is closed.
+                Either::Right((None, _)) => break Ok(()),
+            }
+        }
+    }
+
+    // Naive example that broadcasts the produced values to all active subscribers.
+    fn produce_items(tx: broadcast::Sender<usize>) {
+        for c in 1..=100 {
+            std::thread::sleep(std::time::Duration::from_millis(1));
+            let _ = tx.send(c);
+        }
+    }
+
     #[tokio::test]
     #[traced_test]
     async fn test_rpc_request() {
@@ -533,4 +651,39 @@ mod tests {
         let response: String = client.request("eth_chainId", rpc_params![]).await.unwrap();
         assert_eq!(response, msg);
     }
+
+    #[tokio::test(flavor = "multi_thread")]
+    #[traced_test]
+    async fn test_rpc_subscription() {
+        let endpoint = dummy_endpoint();
+        let server = Builder::default().build(&endpoint).unwrap();
+        let (tx, _rx) = broadcast::channel::<usize>(16);
+
+        let mut module = RpcModule::new(tx.clone());
+        std::thread::spawn(move || produce_items(tx));
+
+        module
+            .register_subscription(
+                "subscribe_hello",
+                "s_hello",
+                "unsubscribe_hello",
+                |_, pending, tx| async move {
+                    let rx = tx.subscribe();
+                    let stream = BroadcastStream::new(rx);
+                    pipe_from_stream_with_bounded_buffer(pending, stream).await?;
+                    Ok(())
+                },
+            )
+            .unwrap();
+
+        let handle = server.start(module).await.unwrap();
+        tokio::spawn(handle.stopped());
+
+        let client = IpcClientBuilder::default().build(endpoint).await.unwrap();
+        let sub: Subscription<usize> =
+            client.subscribe("subscribe_hello", rpc_params![], "unsubscribe_hello").await.unwrap();
+
+        let items = sub.take(16).collect::<Vec<_>>().await;
+        assert_eq!(items.len(), 16);
+    }
 }
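Review note: the sketch below is not part of the diff; it shows the new subscription path and the `set_message_buffer_capacity` knob working end to end. It assumes the crate-internal items the tests above rely on (`Builder`, `IpcClientBuilder`, `dummy_endpoint`) and the same jsonrpsee subscription API they exercise; `demo_backpressure` is an illustrative name and the capacity of 8 is arbitrary. Under those assumptions, once 8 messages are queued for a connection, `sink.send(..)` in the subscription task awaits until the client drains the channel instead of growing an unbounded queue.

```rust
use futures::StreamExt;
use jsonrpsee::{
    core::client::{Subscription, SubscriptionClientT},
    rpc_params, RpcModule, SubscriptionMessage,
};
use parity_tokio_ipc::dummy_endpoint;
// Crate-internal paths, as used by the tests above.
use crate::{client::IpcClientBuilder, server::Builder};

// Hypothetical demo mirroring the `test_rpc_subscription` setup above.
async fn demo_backpressure() -> Result<(), Box<dyn std::error::Error>> {
    let endpoint = dummy_endpoint();
    // Small per-connection buffer: at most 8 responses/notifications are
    // queued before the server stops pulling new work for this connection.
    let server = Builder::default().set_message_buffer_capacity(8).build(&endpoint)?;

    let mut module = RpcModule::new(());
    module.register_subscription(
        "subscribe_hello",
        "s_hello",
        "unsubscribe_hello",
        |_, pending, _| async move {
            let sink = pending.accept().await.unwrap();
            for i in 0..100usize {
                let msg = SubscriptionMessage::from_json(&i)?;
                // Waits while the 8-slot buffer is full (backpressure);
                // stop if the subscriber went away.
                if sink.send(msg).await.is_err() {
                    break
                }
            }
            Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
        },
    )?;

    let handle = server.start(module).await?;
    tokio::spawn(handle.stopped());

    // A slow reader here throttles the producer loop above.
    let client = IpcClientBuilder::default().build(endpoint).await?;
    let sub: Subscription<usize> =
        client.subscribe("subscribe_hello", rpc_params![], "unsubscribe_hello").await?;
    let first_ten: Vec<_> = sub.take(10).collect().await;
    println!("{first_ten:?}");
    Ok(())
}
```

Routing call responses and subscription notifications through one bounded `mpsc` channel keeps the `tokio::select!` loop the single writer on the socket, so message ordering is preserved and a slow client stalls producers rather than inflating server memory.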