Feat/ws relay (#43)

* feat: websocket relay

* support both TCP proxy and ws clients

* clippy

* fix: prevent sending after ws client closed (#38)

* clean up cargo.toml

* fix multiple workspace error

* fix lints

---------

Co-authored-by: sinu <65924192+sinui0@users.noreply.github.com>
This commit is contained in:
dan
2024-10-18 14:49:28 +02:00
committed by GitHub
parent 6f629684d0
commit 0040a00b44
5 changed files with 277 additions and 18 deletions

View File

@@ -1,34 +1,42 @@
[workspace]
members = ["utils", "utils-aio", "spansy", "serio", "uid-mux", "utils/fuzz"]
members = [
"serio",
"spansy",
"uid-mux",
"utils",
"utils-aio",
"utils/fuzz",
"websocket-relay"
]
[workspace.dependencies]
serio = { path = "serio" }
spansy = { path = "spansy" }
tlsn-utils = { path = "utils" }
tlsn-utils-aio = { path = "utils-aio" }
spansy = { path = "spansy" }
serio = { path = "serio" }
uid-mux = { path = "uid-mux" }
rand = "0.8"
thiserror = "1"
async-std = "1"
async-trait = "0.1"
prost = "0.9"
async-tungstenite = "0.16"
bincode = "1.3"
bytes = "1"
cfg-if = "1"
futures = "0.3"
futures-sink = "0.3"
futures-channel = "0.3"
futures-core = "0.3"
futures-io = "0.3"
futures-channel = "0.3"
futures-sink = "0.3"
futures-util = "0.3"
tokio-util = "0.7"
tokio-serde = "0.8"
tokio = "1.23"
async-tungstenite = "0.16"
pin-project-lite = "0.2"
prost = "0.9"
prost-build = "0.9"
bytes = "1"
async-std = "1"
rand = "0.8"
rayon = "1"
serde = "1"
cfg-if = "1"
bincode = "1.3"
pin-project-lite = "0.2"
thiserror = "1"
tokio = "1.23"
tokio-serde = "0.8"
tokio-util = "0.7"
tracing = "0.1"
tracing-subscriber = "0.3"
tracing-subscriber = "0.3"

View File

@@ -27,6 +27,7 @@ mod tests {
use futures_util::StreamExt;
#[derive(Debug)]
#[allow(dead_code)]
enum Msg {
Foo(u8),
Bar(u8),

View File

@@ -0,0 +1,18 @@
[package]
name = "websocket-relay"
version = "0.1.0"
edition = "2021"
authors = ["TLSNotary Contributors"]
license = "MIT OR Apache-2.0"
repository = "https://github.com/tlsnotary/tlsn-utils"
description = """A relay for websocket clients."""
[dependencies]
anyhow = "1"
form_urlencoded = "1.2"
futures = { workspace = true }
once_cell = "1.19"
tokio = { workspace = true, features = ["full"] }
tokio-tungstenite = { version = "0.23", features = ["url"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }

203
websocket-relay/src/lib.rs Normal file
View File

@@ -0,0 +1,203 @@
use std::{
collections::HashMap,
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
},
};
use anyhow::{anyhow, Result};
use futures::{SinkExt, StreamExt as _};
use once_cell::sync::Lazy;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
};
use tokio_tungstenite::{
accept_hdr_async,
tungstenite::{http::Request, Message},
WebSocketStream,
};
use tracing::{debug, info, instrument};
#[derive(Debug, Default)]
struct State {
waiting: HashMap<ConnectionId, WebSocketStream<TcpStream>>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ConnectionId(String);
static STATE: Lazy<Arc<Mutex<State>>> = Lazy::new(Default::default);
enum Mode {
/// Acts a proxy between two websocket clients.
Ws {
id: ConnectionId,
ws: WebSocketStream<TcpStream>,
},
/// Acts as a proxy between a websocket client and a TCP server.
Tcp {
addr: String,
ws: WebSocketStream<TcpStream>,
},
}
/// Runs the websocket relay server with the given TCP listener.
#[instrument]
pub async fn run(listener: TcpListener) -> Result<()> {
loop {
let (socket, addr) = listener.accept().await?;
info!("accepted connection from: {}", addr);
tokio::spawn(handle_connection(addr, socket));
}
}
#[instrument(skip(io), err)]
async fn handle_connection(addr: SocketAddr, io: TcpStream) -> Result<()> {
match accept_ws(io).await? {
Mode::Ws { id, ws } => {
tokio::spawn(handle_ws(id, ws));
}
Mode::Tcp { addr, ws } => {
tokio::spawn(handle_tcp(addr, ws));
}
}
Ok(())
}
#[instrument(level = "debug", skip_all, err)]
async fn accept_ws(io: TcpStream) -> Result<Mode> {
let mut uri = None;
let mut ws = accept_hdr_async(io, |req: &Request<()>, res| {
uri = Some(req.uri().clone());
Ok(res)
})
.await?;
let uri = uri.expect("uri should be set");
let query = uri
.query()
.ok_or_else(|| anyhow!("query string not provided"))?;
let mut params = form_urlencoded::parse(query.as_bytes())
.map(|(k, v)| (k.into_owned(), v.into_owned()))
.collect::<HashMap<String, String>>();
match uri.path() {
"/tcp" => {
let addr = params
.remove("addr")
.ok_or_else(|| anyhow!("addr query parameter not provided"))?;
return Ok(Mode::Tcp { addr, ws });
}
"/ws" => {
let id = params
.remove("id")
.ok_or_else(|| anyhow!("id query parameter not provided"))?;
return Ok(Mode::Ws {
id: ConnectionId(id),
ws,
});
}
_ => {
ws.close(None).await?;
return Err(anyhow!("invalid path: {:?}", uri.path()));
}
}
}
/// Relays messages between two websocket clients.
#[instrument(level = "debug", skip(ws), err)]
async fn handle_ws(id: ConnectionId, ws: WebSocketStream<TcpStream>) -> Result<()> {
let peer = {
let mut state = STATE.lock().unwrap();
if let Some(peer) = state.waiting.remove(&id) {
peer
} else {
state.waiting.insert(id.clone(), ws);
debug!("connection waiting");
return Ok(());
}
};
debug!("started");
let (left_sink, left_stream) = ws.split();
let (right_sink, right_stream) = peer.split();
tokio::try_join!(
left_stream.forward(right_sink),
right_stream.forward(left_sink),
)?;
debug!("connection closed cleanly");
Ok(())
}
/// Relays data between a websocket client and a TCP server.
#[instrument(level = "debug", skip(ws), err)]
async fn handle_tcp(addr: String, ws: WebSocketStream<TcpStream>) -> Result<()> {
let mut tcp = TcpStream::connect(addr).await?;
let (mut sink, mut stream) = ws.split();
let (mut rx, mut tx) = tcp.split();
let is_client_closed = AtomicBool::new(false);
let fut_tx = async {
while let Some(msg) = stream.next().await.transpose()? {
let data = match msg {
Message::Binary(data) => data,
Message::Close(_) => {
break;
}
_ => {
return Err(anyhow!("websocket client sent non-binary message"));
}
};
tx.write_all(&data).await?;
}
debug!("websocket client closed");
is_client_closed.store(true, Ordering::Relaxed);
tx.shutdown().await?;
Ok(())
};
let fut_rx = async {
// 16KB buffer
let mut buf = [0; 16 * 1024];
loop {
let n = rx.read(&mut buf).await?;
if n == 0 {
debug!("tcp server closed");
sink.close().await?;
return Ok(());
}
// Only send to client if it hasn't closed.
if !is_client_closed.load(Ordering::Relaxed) {
sink.send(Message::Binary(buf[..n].to_vec())).await?;
}
}
};
tokio::try_join!(fut_tx, fut_rx)?;
Ok(())
}

View File

@@ -0,0 +1,29 @@
use std::{env, net::IpAddr};
use anyhow::{Context, Result};
use tokio::net::TcpListener;
use tracing::info;
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.init();
let port: u16 = env::var("PROXY_PORT")
.map(|port| port.parse().expect("port should be valid integer"))
.unwrap_or(8080);
let addr: IpAddr = env::var("PROXY_IP")
.map(|addr| addr.parse().expect("should be valid IP address"))
.unwrap_or(IpAddr::V4("127.0.0.1".parse().unwrap()));
let listener = TcpListener::bind((addr, port))
.await
.context("failed to bind to address")?;
info!("listening on: {}", listener.local_addr()?);
websocket_relay::run(listener).await?;
Ok(())
}