rpc: remove boilerplate code in jsonrpc.rs and use macro & clean up

This commit is contained in:
ghassmo
2022-05-13 07:21:08 +03:00
parent 19f2c79ea7
commit 9f58f0aa21

View File

@@ -3,7 +3,7 @@ use std::{env, str, time::Duration};
use async_executor::Executor;
use async_std::io::timeout;
use futures::{AsyncReadExt, AsyncWriteExt};
use futures::{select, AsyncReadExt, AsyncWriteExt, FutureExt};
use log::error;
use rand::Rng;
use serde::{Deserialize, Serialize};
@@ -31,6 +31,7 @@ pub enum ErrorCode {
InvalidTokenIdParam,
InvalidAddressParam,
InvalidSymbolParam,
InvalidId,
ServerError(i64),
}
@@ -51,6 +52,7 @@ impl ErrorCode {
ErrorCode::InvalidTokenIdParam => -32012,
ErrorCode::InvalidAddressParam => -32013,
ErrorCode::InvalidSymbolParam => -32014,
ErrorCode::InvalidId => -32030,
ErrorCode::ServerError(c) => c,
}
}
@@ -71,6 +73,7 @@ impl ErrorCode {
ErrorCode::InvalidTokenIdParam => "Invalid token id param",
ErrorCode::InvalidAddressParam => "Invalid address param",
ErrorCode::InvalidSymbolParam => "Invalid symbol param",
ErrorCode::InvalidId => "Invalid Id",
ErrorCode::ServerError(_) => "Server error",
};
desc.to_string()
@@ -163,8 +166,9 @@ pub fn notification(m: Value, p: Value) -> JsonNotification {
async fn reqrep_loop<T: TransportStream>(
mut stream: T,
data_receiver: async_channel::Receiver<Value>,
result_sender: async_channel::Sender<JsonResult>,
data_receiver: async_channel::Receiver<Value>,
stop_receiver: async_channel::Receiver<()>,
) -> Result<()> {
// If we don't get a reply after 30 seconds, we'll fail.
let read_timeout = Duration::from_secs(30);
@@ -172,30 +176,42 @@ async fn reqrep_loop<T: TransportStream>(
loop {
let mut buf = [0; 8192];
let data = data_receiver.recv().await?;
let data_str = serde_json::to_string(&data)?;
select! {
data = data_receiver.recv().fuse() => {
let data_str = serde_json::to_string(&data?)?;
stream.write_all(data_str.as_bytes()).await?;
stream.write_all(data_str.as_bytes()).await?;
let bytes_read = timeout(read_timeout, async { stream.read(&mut buf[..]).await }).await?;
let bytes_read = timeout(read_timeout, async { stream.read(&mut buf[..]).await }).await?;
let reply: JsonResult = serde_json::from_slice(&buf[0..bytes_read])?;
let reply: JsonResult = serde_json::from_slice(&buf[0..bytes_read])?;
result_sender.send(reply).await?;
result_sender.send(reply).await?;
}
_ = stop_receiver.recv().fuse() => break
}
}
Ok(())
}
pub async fn open_channels(
uri: &Url,
executor: Arc<Executor<'_>>,
) -> Result<(async_channel::Sender<Value>, async_channel::Receiver<JsonResult>)> {
) -> Result<(
async_channel::Sender<Value>,
async_channel::Receiver<JsonResult>,
async_channel::Sender<()>,
)> {
let (data_sender, data_receiver) = async_channel::unbounded();
let (result_sender, result_receiver) = async_channel::unbounded();
let (stop_sender, stop_receiver) = async_channel::unbounded();
let transport_name = TransportName::try_from(uri.clone())?;
macro_rules! hanlde_stream {
($stream:expr, $transport:expr, $upgrade:expr) => {
macro_rules! reqrep {
($stream:expr, $transport:expr, $upgrade:expr) => {{
if let Err(err) = $stream {
error!("RPC Setup for {} failed: {}", uri, err);
return Err(Error::ConnectFailed)
@@ -212,15 +228,19 @@ pub async fn open_channels(
match $upgrade {
None => {
executor.spawn(reqrep_loop(stream, data_receiver, result_sender)).detach();
executor
.spawn(reqrep_loop(stream, result_sender, data_receiver, stop_receiver))
.detach();
}
Some(u) if u == "tls" => {
let stream = $transport.upgrade_dialer(stream)?.await?;
executor.spawn(reqrep_loop(stream, data_receiver, result_sender)).detach();
executor
.spawn(reqrep_loop(stream, result_sender, data_receiver, stop_receiver))
.detach();
}
Some(u) => return Err(Error::UnsupportedTransportUpgrade(u)),
}
};
}};
}
match transport_name {
@@ -228,7 +248,7 @@ pub async fn open_channels(
let transport = TcpTransport::new(None, 1024);
let stream = transport.dial(uri.clone());
hanlde_stream!(stream, transport, upgrade);
reqrep!(stream, transport, upgrade);
}
TransportName::Tor(upgrade) => {
let socks5_url = Url::parse(
@@ -240,7 +260,7 @@ pub async fn open_channels(
let stream = transport.clone().dial(uri.clone());
hanlde_stream!(stream, transport, upgrade);
reqrep!(stream, transport, upgrade);
}
TransportName::Unix => {
let transport = UnixTransport::new();
@@ -252,12 +272,14 @@ pub async fn open_channels(
return Err(Error::ConnectFailed)
}
executor.spawn(reqrep_loop(stream?, data_receiver, result_sender)).detach();
executor
.spawn(reqrep_loop(stream?, result_sender, data_receiver, stop_receiver))
.detach();
}
_ => unimplemented!(),
}
Ok((data_sender, result_receiver))
Ok((data_sender, result_receiver, stop_sender))
}
pub async fn send_request(uri: &Url, data: Value) -> Result<JsonResult> {
@@ -265,31 +287,39 @@ pub async fn send_request(uri: &Url, data: Value) -> Result<JsonResult> {
let transport_name = TransportName::try_from(uri.clone())?;
macro_rules! reply {
($stream:expr, $transport:expr, $upgrade:expr) => {{
if let Err(err) = $stream {
error!("RPC Setup for {} failed: {}", uri, err);
return Err(Error::ConnectFailed)
}
let stream = $stream?.await;
if let Err(err) = stream {
error!("RPC Connection to {} failed: {}", uri, err);
return Err(Error::ConnectFailed)
}
let stream = stream?;
match $upgrade {
None => get_reply(stream, data_str).await,
Some(u) if u == "tls" => {
let stream = $transport.upgrade_dialer(stream)?.await?;
get_reply(stream, data_str).await
}
Some(u) => Err(Error::UnsupportedTransportUpgrade(u)),
}
}};
}
match transport_name {
TransportName::Tcp(upgrade) => {
let transport = TcpTransport::new(None, 1024);
let stream = transport.dial(uri.clone());
if let Err(err) = stream {
error!("RPC Setup for {} failed: {}", uri, err);
return Err(Error::ConnectFailed)
}
let stream = stream?.await;
if let Err(err) = stream {
error!("RPC Connection to {} failed: {}", uri, err);
return Err(Error::ConnectFailed)
}
match upgrade {
None => get_reply(&mut stream?, data_str).await,
Some(u) if u == "tls" => {
let mut stream = transport.upgrade_dialer(stream?)?.await?;
get_reply(&mut stream, data_str).await
}
Some(u) => Err(Error::UnsupportedTransportUpgrade(u)),
}
reply!(stream, transport, upgrade)
}
TransportName::Tor(upgrade) => {
let socks5_url = Url::parse(
@@ -301,26 +331,7 @@ pub async fn send_request(uri: &Url, data: Value) -> Result<JsonResult> {
let stream = transport.clone().dial(uri.clone());
if let Err(err) = stream {
error!("RPC Setup for {} failed: {}", uri, err);
return Err(Error::ConnectFailed)
}
let stream = stream?.await;
if let Err(err) = stream {
error!("RPC Connection to {} failed: {}", uri, err);
return Err(Error::ConnectFailed)
}
match upgrade {
None => get_reply(&mut stream?, data_str).await,
Some(u) if u == "tls" => {
let mut stream = transport.upgrade_dialer(stream?)?.await?;
get_reply(&mut stream, data_str).await
}
Some(u) => Err(Error::UnsupportedTransportUpgrade(u)),
}
reply!(stream, transport, upgrade)
}
TransportName::Unix => {
let transport = UnixTransport::new();
@@ -332,13 +343,13 @@ pub async fn send_request(uri: &Url, data: Value) -> Result<JsonResult> {
return Err(Error::ConnectFailed)
}
get_reply(&mut stream?, data_str).await
get_reply(stream?, data_str).await
}
_ => unimplemented!(),
}
}
async fn get_reply<T: TransportStream>(stream: &mut T, data_str: String) -> Result<JsonResult> {
async fn get_reply<T: TransportStream>(mut stream: T, data_str: String) -> Result<JsonResult> {
// If we don't get a reply after 30 seconds, we'll fail.
let read_timeout = Duration::from_secs(30);