mirror of
https://github.com/tlsnotary/tlsn-utils.git
synced 2026-01-09 20:57:56 -05:00
Feat/ws relay (#43)
* feat: websocket relay * support both TCP proxy and ws clients * clippy * fix: prevent sending after ws client closed (#38) * clean up cargo.toml * fix multiple workspace error * fix lints --------- Co-authored-by: sinu <65924192+sinui0@users.noreply.github.com>
This commit is contained in:
44
Cargo.toml
44
Cargo.toml
@@ -1,34 +1,42 @@
|
||||
[workspace]
|
||||
members = ["utils", "utils-aio", "spansy", "serio", "uid-mux", "utils/fuzz"]
|
||||
members = [
|
||||
"serio",
|
||||
"spansy",
|
||||
"uid-mux",
|
||||
"utils",
|
||||
"utils-aio",
|
||||
"utils/fuzz",
|
||||
"websocket-relay"
|
||||
]
|
||||
|
||||
[workspace.dependencies]
|
||||
serio = { path = "serio" }
|
||||
spansy = { path = "spansy" }
|
||||
tlsn-utils = { path = "utils" }
|
||||
tlsn-utils-aio = { path = "utils-aio" }
|
||||
spansy = { path = "spansy" }
|
||||
serio = { path = "serio" }
|
||||
uid-mux = { path = "uid-mux" }
|
||||
|
||||
rand = "0.8"
|
||||
thiserror = "1"
|
||||
async-std = "1"
|
||||
async-trait = "0.1"
|
||||
prost = "0.9"
|
||||
async-tungstenite = "0.16"
|
||||
bincode = "1.3"
|
||||
bytes = "1"
|
||||
cfg-if = "1"
|
||||
futures = "0.3"
|
||||
futures-sink = "0.3"
|
||||
futures-channel = "0.3"
|
||||
futures-core = "0.3"
|
||||
futures-io = "0.3"
|
||||
futures-channel = "0.3"
|
||||
futures-sink = "0.3"
|
||||
futures-util = "0.3"
|
||||
tokio-util = "0.7"
|
||||
tokio-serde = "0.8"
|
||||
tokio = "1.23"
|
||||
async-tungstenite = "0.16"
|
||||
pin-project-lite = "0.2"
|
||||
prost = "0.9"
|
||||
prost-build = "0.9"
|
||||
bytes = "1"
|
||||
async-std = "1"
|
||||
rand = "0.8"
|
||||
rayon = "1"
|
||||
serde = "1"
|
||||
cfg-if = "1"
|
||||
bincode = "1.3"
|
||||
pin-project-lite = "0.2"
|
||||
thiserror = "1"
|
||||
tokio = "1.23"
|
||||
tokio-serde = "0.8"
|
||||
tokio-util = "0.7"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = "0.3"
|
||||
tracing-subscriber = "0.3"
|
||||
@@ -27,6 +27,7 @@ mod tests {
|
||||
use futures_util::StreamExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
enum Msg {
|
||||
Foo(u8),
|
||||
Bar(u8),
|
||||
|
||||
18
websocket-relay/Cargo.toml
Normal file
18
websocket-relay/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "websocket-relay"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
authors = ["TLSNotary Contributors"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
repository = "https://github.com/tlsnotary/tlsn-utils"
|
||||
description = """A relay for websocket clients."""
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
form_urlencoded = "1.2"
|
||||
futures = { workspace = true }
|
||||
once_cell = "1.19"
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
tokio-tungstenite = { version = "0.23", features = ["url"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter"] }
|
||||
203
websocket-relay/src/lib.rs
Normal file
203
websocket-relay/src/lib.rs
Normal file
@@ -0,0 +1,203 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
net::SocketAddr,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc, Mutex,
|
||||
},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{SinkExt, StreamExt as _};
|
||||
use once_cell::sync::Lazy;
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::{TcpListener, TcpStream},
|
||||
};
|
||||
use tokio_tungstenite::{
|
||||
accept_hdr_async,
|
||||
tungstenite::{http::Request, Message},
|
||||
WebSocketStream,
|
||||
};
|
||||
use tracing::{debug, info, instrument};
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct State {
|
||||
waiting: HashMap<ConnectionId, WebSocketStream<TcpStream>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
struct ConnectionId(String);
|
||||
|
||||
static STATE: Lazy<Arc<Mutex<State>>> = Lazy::new(Default::default);
|
||||
|
||||
enum Mode {
|
||||
/// Acts a proxy between two websocket clients.
|
||||
Ws {
|
||||
id: ConnectionId,
|
||||
ws: WebSocketStream<TcpStream>,
|
||||
},
|
||||
/// Acts as a proxy between a websocket client and a TCP server.
|
||||
Tcp {
|
||||
addr: String,
|
||||
ws: WebSocketStream<TcpStream>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Runs the websocket relay server with the given TCP listener.
|
||||
#[instrument]
|
||||
pub async fn run(listener: TcpListener) -> Result<()> {
|
||||
loop {
|
||||
let (socket, addr) = listener.accept().await?;
|
||||
info!("accepted connection from: {}", addr);
|
||||
|
||||
tokio::spawn(handle_connection(addr, socket));
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(io), err)]
|
||||
async fn handle_connection(addr: SocketAddr, io: TcpStream) -> Result<()> {
|
||||
match accept_ws(io).await? {
|
||||
Mode::Ws { id, ws } => {
|
||||
tokio::spawn(handle_ws(id, ws));
|
||||
}
|
||||
Mode::Tcp { addr, ws } => {
|
||||
tokio::spawn(handle_tcp(addr, ws));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn accept_ws(io: TcpStream) -> Result<Mode> {
|
||||
let mut uri = None;
|
||||
|
||||
let mut ws = accept_hdr_async(io, |req: &Request<()>, res| {
|
||||
uri = Some(req.uri().clone());
|
||||
|
||||
Ok(res)
|
||||
})
|
||||
.await?;
|
||||
|
||||
let uri = uri.expect("uri should be set");
|
||||
let query = uri
|
||||
.query()
|
||||
.ok_or_else(|| anyhow!("query string not provided"))?;
|
||||
let mut params = form_urlencoded::parse(query.as_bytes())
|
||||
.map(|(k, v)| (k.into_owned(), v.into_owned()))
|
||||
.collect::<HashMap<String, String>>();
|
||||
|
||||
match uri.path() {
|
||||
"/tcp" => {
|
||||
let addr = params
|
||||
.remove("addr")
|
||||
.ok_or_else(|| anyhow!("addr query parameter not provided"))?;
|
||||
|
||||
return Ok(Mode::Tcp { addr, ws });
|
||||
}
|
||||
"/ws" => {
|
||||
let id = params
|
||||
.remove("id")
|
||||
.ok_or_else(|| anyhow!("id query parameter not provided"))?;
|
||||
|
||||
return Ok(Mode::Ws {
|
||||
id: ConnectionId(id),
|
||||
ws,
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
ws.close(None).await?;
|
||||
|
||||
return Err(anyhow!("invalid path: {:?}", uri.path()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Relays messages between two websocket clients.
|
||||
#[instrument(level = "debug", skip(ws), err)]
|
||||
async fn handle_ws(id: ConnectionId, ws: WebSocketStream<TcpStream>) -> Result<()> {
|
||||
let peer = {
|
||||
let mut state = STATE.lock().unwrap();
|
||||
if let Some(peer) = state.waiting.remove(&id) {
|
||||
peer
|
||||
} else {
|
||||
state.waiting.insert(id.clone(), ws);
|
||||
|
||||
debug!("connection waiting");
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
debug!("started");
|
||||
|
||||
let (left_sink, left_stream) = ws.split();
|
||||
let (right_sink, right_stream) = peer.split();
|
||||
|
||||
tokio::try_join!(
|
||||
left_stream.forward(right_sink),
|
||||
right_stream.forward(left_sink),
|
||||
)?;
|
||||
|
||||
debug!("connection closed cleanly");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Relays data between a websocket client and a TCP server.
|
||||
#[instrument(level = "debug", skip(ws), err)]
|
||||
async fn handle_tcp(addr: String, ws: WebSocketStream<TcpStream>) -> Result<()> {
|
||||
let mut tcp = TcpStream::connect(addr).await?;
|
||||
|
||||
let (mut sink, mut stream) = ws.split();
|
||||
let (mut rx, mut tx) = tcp.split();
|
||||
|
||||
let is_client_closed = AtomicBool::new(false);
|
||||
|
||||
let fut_tx = async {
|
||||
while let Some(msg) = stream.next().await.transpose()? {
|
||||
let data = match msg {
|
||||
Message::Binary(data) => data,
|
||||
Message::Close(_) => {
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
return Err(anyhow!("websocket client sent non-binary message"));
|
||||
}
|
||||
};
|
||||
|
||||
tx.write_all(&data).await?;
|
||||
}
|
||||
|
||||
debug!("websocket client closed");
|
||||
is_client_closed.store(true, Ordering::Relaxed);
|
||||
|
||||
tx.shutdown().await?;
|
||||
|
||||
Ok(())
|
||||
};
|
||||
|
||||
let fut_rx = async {
|
||||
// 16KB buffer
|
||||
let mut buf = [0; 16 * 1024];
|
||||
loop {
|
||||
let n = rx.read(&mut buf).await?;
|
||||
|
||||
if n == 0 {
|
||||
debug!("tcp server closed");
|
||||
sink.close().await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Only send to client if it hasn't closed.
|
||||
if !is_client_closed.load(Ordering::Relaxed) {
|
||||
sink.send(Message::Binary(buf[..n].to_vec())).await?;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
tokio::try_join!(fut_tx, fut_rx)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
29
websocket-relay/src/main.rs
Normal file
29
websocket-relay/src/main.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use std::{env, net::IpAddr};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use tokio::net::TcpListener;
|
||||
use tracing::info;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.init();
|
||||
|
||||
let port: u16 = env::var("PROXY_PORT")
|
||||
.map(|port| port.parse().expect("port should be valid integer"))
|
||||
.unwrap_or(8080);
|
||||
let addr: IpAddr = env::var("PROXY_IP")
|
||||
.map(|addr| addr.parse().expect("should be valid IP address"))
|
||||
.unwrap_or(IpAddr::V4("127.0.0.1".parse().unwrap()));
|
||||
|
||||
let listener = TcpListener::bind((addr, port))
|
||||
.await
|
||||
.context("failed to bind to address")?;
|
||||
|
||||
info!("listening on: {}", listener.local_addr()?);
|
||||
|
||||
websocket_relay::run(listener).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user