From e589fc0b2b1c38ac88d360aa0084ffdc971729a7 Mon Sep 17 00:00:00 2001 From: parazyd Date: Tue, 21 Nov 2023 10:37:12 +0100 Subject: [PATCH] rpc: Simplify stream reading and move timeouts to outer scope. --- src/rpc/client.rs | 11 ++++-- src/rpc/common.rs | 85 +++++++++++++++-------------------------------- src/rpc/server.rs | 10 +++--- 3 files changed, 39 insertions(+), 67 deletions(-) diff --git a/src/rpc/client.rs b/src/rpc/client.rs index d974de357..e4a46490c 100644 --- a/src/rpc/client.rs +++ b/src/rpc/client.rs @@ -24,12 +24,12 @@ use tinyjson::JsonValue; use url::Url; use super::{ - common::{read_from_stream, write_to_stream, INIT_BUF_SIZE}, + common::{read_from_stream, write_to_stream, INIT_BUF_SIZE, READ_TIMEOUT}, jsonrpc::*, }; use crate::{ net::transport::{Dialer, PtStream}, - system::{StoppableTask, StoppableTaskPtr, SubscriberPtr}, + system::{io_timeout, StoppableTask, StoppableTaskPtr, SubscriberPtr}, Error, Result, }; @@ -103,7 +103,12 @@ impl RpcClient { let request = JsonResult::Request(request); write_to_stream(&mut writer, &request).await?; - let _ = read_from_stream(&mut reader, &mut buf, with_timeout).await?; + if with_timeout { + let _ = io_timeout(READ_TIMEOUT, read_from_stream(&mut reader, &mut buf)).await?; + } else { + let _ = read_from_stream(&mut reader, &mut buf).await?; + } + let val: JsonValue = String::from_utf8(buf)?.parse()?; let rep = JsonResult::try_from_value(&val)?; rep_send.send(rep).await?; diff --git a/src/rpc/common.rs b/src/rpc/common.rs index 91d186988..d65ed9571 100644 --- a/src/rpc/common.rs +++ b/src/rpc/common.rs @@ -16,12 +16,12 @@ * along with this program. If not, see . */ -use std::time::Duration; +use std::{io, time::Duration}; use smol::io::{AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; use super::jsonrpc::*; -use crate::{error::RpcError, net::transport::PtStream, system::io_timeout, Result}; +use crate::net::transport::PtStream; pub(super) const INIT_BUF_SIZE: usize = 4096; // 4K pub(super) const MAX_BUF_SIZE: usize = 1024 * 8192; // 8M @@ -32,8 +32,7 @@ pub(super) const READ_TIMEOUT: Duration = Duration::from_secs(30); pub(super) async fn read_from_stream( reader: &mut BufReader>>, buf: &mut Vec, - with_timeout: bool, -) -> Result { +) -> io::Result { let mut total_read = 0; // Intermediate buffer we use to read byte-by-byte. @@ -42,61 +41,29 @@ pub(super) async fn read_from_stream( while total_read < MAX_BUF_SIZE { buf.resize(total_read + INIT_BUF_SIZE, 0); - // Lame we have to duplicate this code, but it is what it is. - if with_timeout { - match io_timeout(READ_TIMEOUT, reader.read(&mut tmpbuf)).await { - Ok(0) if total_read == 0 => { - return Err( - RpcError::ConnectionClosed("Connection closed cleanly".to_string()).into() - ) - } - Ok(0) => break, // Finished reading - Ok(_) => { - // When we reach '\n', pop a possible '\r' from the buffer and bail. - if tmpbuf[0] == b'\n' { - if buf[total_read - 1] == b'\r' { - buf.pop(); - total_read -= 1; - } - break + match reader.read(&mut tmpbuf).await { + Ok(0) if total_read == 0 => return Err(io::ErrorKind::ConnectionAborted.into()), + Ok(0) => break, // Finished reading + Ok(_) => { + // When we reach '\n', pop a possible '\r' from the buffer and bail. + if tmpbuf[0] == b'\n' { + if buf[total_read - 1] == b'\r' { + buf.pop(); + total_read -= 1; } - - // Copy the read byte to the destination buffer. - buf[total_read] = tmpbuf[0]; - total_read += 1; + break } - Err(e) => return Err(RpcError::IoError(e.kind()).into()), + // Copy the read byte to the destination buffer. + buf[total_read] = tmpbuf[0]; + total_read += 1; } - } else { - match reader.read(&mut tmpbuf).await { - Ok(0) if total_read == 0 => { - return Err( - RpcError::ConnectionClosed("Connection closed cleanly".to_string()).into() - ) - } - Ok(0) => break, // Finished reading - Ok(_) => { - // When we reach '\n', pop a possible '\r' from the buffer and bail. - if tmpbuf[0] == b'\n' { - if buf[total_read - 1] == b'\r' { - buf.pop(); - total_read -= 1; - } - break - } - // Copy the read byte to the destination buffer. - buf[total_read] = tmpbuf[0]; - total_read += 1; - } - - Err(e) => return Err(RpcError::IoError(e.kind()).into()), - } + Err(e) => return Err(e), } } - // Trunacate buffer to actual data size + // Truncate buffer to actual data size buf.truncate(total_read); Ok(total_read) } @@ -105,21 +72,21 @@ pub(super) async fn read_from_stream( pub(super) async fn write_to_stream( writer: &mut WriteHalf>, object: &JsonResult, -) -> Result<()> { +) -> io::Result<()> { let object_str = match object { - JsonResult::Notification(v) => v.stringify()?, - JsonResult::Response(v) => v.stringify()?, - JsonResult::Error(v) => v.stringify()?, - JsonResult::Request(v) => v.stringify()?, + JsonResult::Notification(v) => v.stringify().unwrap(), + JsonResult::Response(v) => v.stringify().unwrap(), + JsonResult::Error(v) => v.stringify().unwrap(), + JsonResult::Request(v) => v.stringify().unwrap(), _ => unreachable!(), }; // As we're a line-based protocol, we append CRLF to the end of the JSON string. for i in [object_str.as_bytes(), &[b'\r', b'\n']] { - if let Err(e) = writer.write_all(i).await { - return Err(e.into()) - } + writer.write_all(i).await? } + writer.flush().await?; + Ok(()) } diff --git a/src/rpc/server.rs b/src/rpc/server.rs index d988adb77..f67c6a8ad 100644 --- a/src/rpc/server.rs +++ b/src/rpc/server.rs @@ -103,7 +103,7 @@ pub async fn accept( let mut buf = Vec::with_capacity(INIT_BUF_SIZE); let mut reader_lock = reader.lock().await; - let _ = read_from_stream(&mut reader_lock, &mut buf, false).await?; + let _ = read_from_stream(&mut reader_lock, &mut buf).await?; drop(reader_lock); let line = match String::from_utf8(buf) { @@ -165,13 +165,13 @@ pub async fn accept( let notification = subscription.receive().await; // Push notification - debug!(target: "rpc::server", "{} <-- {}", addr_, notification.stringify()?); + debug!(target: "rpc::server", "{} <-- {}", addr_, notification.stringify().unwrap()); let notification = JsonResult::Notification(notification); let mut writer_lock = writer_.lock().await; if let Err(e) = write_to_stream(&mut writer_lock, ¬ification).await { subscription.unsubscribe().await; - return Err(e) + return Err(e.into()) } drop(writer_lock); } @@ -215,14 +215,14 @@ pub async fn accept( let notification = subscription.receive().await; // Push notification - debug!(target: "rpc::server", "{} <-- {}", addr_, notification.stringify()?); + debug!(target: "rpc::server", "{} <-- {}", addr_, notification.stringify().unwrap()); let notification = JsonResult::Notification(notification); let mut writer_lock = writer_.lock().await; if let Err(e) = write_to_stream(&mut writer_lock, ¬ification).await { subscription.unsubscribe().await; drop(writer_lock); - return Err(e) + return Err(e.into()) } drop(writer_lock); }