diff --git a/src/rpc/jsonrpc.rs b/src/rpc/jsonrpc.rs index 018672011..879026be6 100644 --- a/src/rpc/jsonrpc.rs +++ b/src/rpc/jsonrpc.rs @@ -3,7 +3,7 @@ use std::{env, str, time::Duration}; use async_executor::Executor; use async_std::io::timeout; -use futures::{AsyncReadExt, AsyncWriteExt}; +use futures::{select, AsyncReadExt, AsyncWriteExt, FutureExt}; use log::error; use rand::Rng; use serde::{Deserialize, Serialize}; @@ -31,6 +31,7 @@ pub enum ErrorCode { InvalidTokenIdParam, InvalidAddressParam, InvalidSymbolParam, + InvalidId, ServerError(i64), } @@ -51,6 +52,7 @@ impl ErrorCode { ErrorCode::InvalidTokenIdParam => -32012, ErrorCode::InvalidAddressParam => -32013, ErrorCode::InvalidSymbolParam => -32014, + ErrorCode::InvalidId => -32030, ErrorCode::ServerError(c) => c, } } @@ -71,6 +73,7 @@ impl ErrorCode { ErrorCode::InvalidTokenIdParam => "Invalid token id param", ErrorCode::InvalidAddressParam => "Invalid address param", ErrorCode::InvalidSymbolParam => "Invalid symbol param", + ErrorCode::InvalidId => "Invalid Id", ErrorCode::ServerError(_) => "Server error", }; desc.to_string() @@ -163,8 +166,9 @@ pub fn notification(m: Value, p: Value) -> JsonNotification { async fn reqrep_loop( mut stream: T, - data_receiver: async_channel::Receiver, result_sender: async_channel::Sender, + data_receiver: async_channel::Receiver, + stop_receiver: async_channel::Receiver<()>, ) -> Result<()> { // If we don't get a reply after 30 seconds, we'll fail. let read_timeout = Duration::from_secs(30); @@ -172,30 +176,42 @@ async fn reqrep_loop( loop { let mut buf = [0; 8192]; - let data = data_receiver.recv().await?; - let data_str = serde_json::to_string(&data)?; + select! { + data = data_receiver.recv().fuse() => { + let data_str = serde_json::to_string(&data?)?; - stream.write_all(data_str.as_bytes()).await?; + stream.write_all(data_str.as_bytes()).await?; - let bytes_read = timeout(read_timeout, async { stream.read(&mut buf[..]).await }).await?; + let bytes_read = timeout(read_timeout, async { stream.read(&mut buf[..]).await }).await?; - let reply: JsonResult = serde_json::from_slice(&buf[0..bytes_read])?; + let reply: JsonResult = serde_json::from_slice(&buf[0..bytes_read])?; - result_sender.send(reply).await?; + result_sender.send(reply).await?; + } + + _ = stop_receiver.recv().fuse() => break + } } + + Ok(()) } pub async fn open_channels( uri: &Url, executor: Arc>, -) -> Result<(async_channel::Sender, async_channel::Receiver)> { +) -> Result<( + async_channel::Sender, + async_channel::Receiver, + async_channel::Sender<()>, +)> { let (data_sender, data_receiver) = async_channel::unbounded(); let (result_sender, result_receiver) = async_channel::unbounded(); + let (stop_sender, stop_receiver) = async_channel::unbounded(); let transport_name = TransportName::try_from(uri.clone())?; - macro_rules! hanlde_stream { - ($stream:expr, $transport:expr, $upgrade:expr) => { + macro_rules! reqrep { + ($stream:expr, $transport:expr, $upgrade:expr) => {{ if let Err(err) = $stream { error!("RPC Setup for {} failed: {}", uri, err); return Err(Error::ConnectFailed) @@ -212,15 +228,19 @@ pub async fn open_channels( match $upgrade { None => { - executor.spawn(reqrep_loop(stream, data_receiver, result_sender)).detach(); + executor + .spawn(reqrep_loop(stream, result_sender, data_receiver, stop_receiver)) + .detach(); } Some(u) if u == "tls" => { let stream = $transport.upgrade_dialer(stream)?.await?; - executor.spawn(reqrep_loop(stream, data_receiver, result_sender)).detach(); + executor + .spawn(reqrep_loop(stream, result_sender, data_receiver, stop_receiver)) + .detach(); } Some(u) => return Err(Error::UnsupportedTransportUpgrade(u)), } - }; + }}; } match transport_name { @@ -228,7 +248,7 @@ pub async fn open_channels( let transport = TcpTransport::new(None, 1024); let stream = transport.dial(uri.clone()); - hanlde_stream!(stream, transport, upgrade); + reqrep!(stream, transport, upgrade); } TransportName::Tor(upgrade) => { let socks5_url = Url::parse( @@ -240,7 +260,7 @@ pub async fn open_channels( let stream = transport.clone().dial(uri.clone()); - hanlde_stream!(stream, transport, upgrade); + reqrep!(stream, transport, upgrade); } TransportName::Unix => { let transport = UnixTransport::new(); @@ -252,12 +272,14 @@ pub async fn open_channels( return Err(Error::ConnectFailed) } - executor.spawn(reqrep_loop(stream?, data_receiver, result_sender)).detach(); + executor + .spawn(reqrep_loop(stream?, result_sender, data_receiver, stop_receiver)) + .detach(); } _ => unimplemented!(), } - Ok((data_sender, result_receiver)) + Ok((data_sender, result_receiver, stop_sender)) } pub async fn send_request(uri: &Url, data: Value) -> Result { @@ -265,31 +287,39 @@ pub async fn send_request(uri: &Url, data: Value) -> Result { let transport_name = TransportName::try_from(uri.clone())?; + macro_rules! reply { + ($stream:expr, $transport:expr, $upgrade:expr) => {{ + if let Err(err) = $stream { + error!("RPC Setup for {} failed: {}", uri, err); + return Err(Error::ConnectFailed) + } + + let stream = $stream?.await; + + if let Err(err) = stream { + error!("RPC Connection to {} failed: {}", uri, err); + return Err(Error::ConnectFailed) + } + + let stream = stream?; + + match $upgrade { + None => get_reply(stream, data_str).await, + Some(u) if u == "tls" => { + let stream = $transport.upgrade_dialer(stream)?.await?; + get_reply(stream, data_str).await + } + Some(u) => Err(Error::UnsupportedTransportUpgrade(u)), + } + }}; + } + match transport_name { TransportName::Tcp(upgrade) => { let transport = TcpTransport::new(None, 1024); let stream = transport.dial(uri.clone()); - if let Err(err) = stream { - error!("RPC Setup for {} failed: {}", uri, err); - return Err(Error::ConnectFailed) - } - - let stream = stream?.await; - - if let Err(err) = stream { - error!("RPC Connection to {} failed: {}", uri, err); - return Err(Error::ConnectFailed) - } - - match upgrade { - None => get_reply(&mut stream?, data_str).await, - Some(u) if u == "tls" => { - let mut stream = transport.upgrade_dialer(stream?)?.await?; - get_reply(&mut stream, data_str).await - } - Some(u) => Err(Error::UnsupportedTransportUpgrade(u)), - } + reply!(stream, transport, upgrade) } TransportName::Tor(upgrade) => { let socks5_url = Url::parse( @@ -301,26 +331,7 @@ pub async fn send_request(uri: &Url, data: Value) -> Result { let stream = transport.clone().dial(uri.clone()); - if let Err(err) = stream { - error!("RPC Setup for {} failed: {}", uri, err); - return Err(Error::ConnectFailed) - } - - let stream = stream?.await; - - if let Err(err) = stream { - error!("RPC Connection to {} failed: {}", uri, err); - return Err(Error::ConnectFailed) - } - - match upgrade { - None => get_reply(&mut stream?, data_str).await, - Some(u) if u == "tls" => { - let mut stream = transport.upgrade_dialer(stream?)?.await?; - get_reply(&mut stream, data_str).await - } - Some(u) => Err(Error::UnsupportedTransportUpgrade(u)), - } + reply!(stream, transport, upgrade) } TransportName::Unix => { let transport = UnixTransport::new(); @@ -332,13 +343,13 @@ pub async fn send_request(uri: &Url, data: Value) -> Result { return Err(Error::ConnectFailed) } - get_reply(&mut stream?, data_str).await + get_reply(stream?, data_str).await } _ => unimplemented!(), } } -async fn get_reply(stream: &mut T, data_str: String) -> Result { +async fn get_reply(mut stream: T, data_str: String) -> Result { // If we don't get a reply after 30 seconds, we'll fail. let read_timeout = Duration::from_secs(30);