mirror of
https://github.com/tlsnotary/tlsn-utils.git
synced 2026-01-09 15:08:05 -05:00
feat: futures-limit and futures-plex (#51)
* feat: futures-limit and futures-plex * use web time for wasm
This commit is contained in:
@@ -7,10 +7,13 @@ members = [
|
|||||||
"utils-aio",
|
"utils-aio",
|
||||||
"utils/fuzz",
|
"utils/fuzz",
|
||||||
"websocket-relay",
|
"websocket-relay",
|
||||||
|
"futures-limit",
|
||||||
|
"futures-plex",
|
||||||
"web-spawn",
|
"web-spawn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
|
futures-plex = { path = "futures-plex" }
|
||||||
serio = { path = "serio" }
|
serio = { path = "serio" }
|
||||||
spansy = { path = "spansy" }
|
spansy = { path = "spansy" }
|
||||||
tlsn-utils = { path = "utils" }
|
tlsn-utils = { path = "utils" }
|
||||||
@@ -23,6 +26,7 @@ async-tungstenite = "0.16"
|
|||||||
bincode = "1.3"
|
bincode = "1.3"
|
||||||
bytes = "1"
|
bytes = "1"
|
||||||
cfg-if = "1"
|
cfg-if = "1"
|
||||||
|
criterion = "0.5"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
futures-channel = "0.3"
|
futures-channel = "0.3"
|
||||||
futures-core = "0.3"
|
futures-core = "0.3"
|
||||||
@@ -30,6 +34,7 @@ futures-io = "0.3"
|
|||||||
futures-sink = "0.3"
|
futures-sink = "0.3"
|
||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
pin-project-lite = "0.2"
|
pin-project-lite = "0.2"
|
||||||
|
pollster = "0.4"
|
||||||
prost = "0.9"
|
prost = "0.9"
|
||||||
prost-build = "0.9"
|
prost-build = "0.9"
|
||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
|
|||||||
24
futures-limit/Cargo.toml
Normal file
24
futures-limit/Cargo.toml
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
[package]
|
||||||
|
name = "futures-limit"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
bytes = { workspace = true }
|
||||||
|
futures = { workspace = true, features = ["bilock", "unstable"] }
|
||||||
|
futures-timer = { version = "3" }
|
||||||
|
pin-project-lite = { workspace = true }
|
||||||
|
|
||||||
|
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||||
|
futures-timer = { version = "3", features = ["wasm-bindgen"] }
|
||||||
|
web-time = { version = "1.1" }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
criterion = { workspace = true }
|
||||||
|
pollster = { workspace = true, features = ["macro"] }
|
||||||
|
mock_instant = "0.5"
|
||||||
|
futures-plex = { workspace = true }
|
||||||
|
|
||||||
|
[[bench]]
|
||||||
|
name = "bench"
|
||||||
|
harness = false
|
||||||
3
futures-limit/README.md
Normal file
3
futures-limit/README.md
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# futures-limit
|
||||||
|
|
||||||
|
This crate provides a rate limiting wrapper for `AsyncWrite` and a delay wrapper for `AsyncRead`.
|
||||||
56
futures-limit/benches/bench.rs
Normal file
56
futures-limit/benches/bench.rs
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
use criterion::{BenchmarkId, Criterion, Throughput, black_box, criterion_group, criterion_main};
|
||||||
|
use futures::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use futures_limit::AsyncWriteLimitExt;
|
||||||
|
use futures_plex::simplex;
|
||||||
|
use pollster::FutureExt as _;
|
||||||
|
|
||||||
|
const M: usize = 1 << 20;
|
||||||
|
|
||||||
|
pub fn criterion_benchmark(c: &mut Criterion) {
|
||||||
|
let mut group = c.benchmark_group("rate");
|
||||||
|
|
||||||
|
group.throughput(Throughput::Bytes(M as u64));
|
||||||
|
group.bench_function("max", |b| {
|
||||||
|
let (mut rx, tx) = simplex(M);
|
||||||
|
let mut tx = tx.limit_rate(8 * M, usize::MAX);
|
||||||
|
let tx_buf = vec![0; M];
|
||||||
|
let mut rx_buf = vec![0; M];
|
||||||
|
|
||||||
|
b.iter(|| {
|
||||||
|
async {
|
||||||
|
futures::try_join!(tx.write_all(&tx_buf), rx.read_exact(&mut rx_buf)).unwrap();
|
||||||
|
}
|
||||||
|
.block_on();
|
||||||
|
black_box(&rx_buf);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
for mega_bits_per_sec in [10, 100, 1000] {
|
||||||
|
// 1 ms of data.
|
||||||
|
let size = mega_bits_per_sec * M / 1000 / 8;
|
||||||
|
|
||||||
|
group.throughput(Throughput::Bytes(size as u64));
|
||||||
|
group.bench_function(BenchmarkId::from_parameter(mega_bits_per_sec), |b| {
|
||||||
|
let (mut rx, tx) = simplex(M);
|
||||||
|
|
||||||
|
// 2ms burst
|
||||||
|
let burst = mega_bits_per_sec * M / 500;
|
||||||
|
|
||||||
|
let mut tx = tx.limit_rate(burst, mega_bits_per_sec * M);
|
||||||
|
|
||||||
|
let tx_buf = vec![0; size];
|
||||||
|
let mut rx_buf = vec![0; size];
|
||||||
|
|
||||||
|
b.iter(|| {
|
||||||
|
async {
|
||||||
|
futures::try_join!(tx.write_all(&tx_buf), rx.read_exact(&mut rx_buf)).unwrap();
|
||||||
|
}
|
||||||
|
.block_on();
|
||||||
|
black_box(&rx_buf);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(benches, criterion_benchmark);
|
||||||
|
criterion_main!(benches);
|
||||||
63
futures-limit/src/bucket.rs
Normal file
63
futures-limit/src/bucket.rs
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
use std::{future::Future, pin::Pin, task::Context, time::Duration};
|
||||||
|
|
||||||
|
use futures_timer::Delay;
|
||||||
|
|
||||||
|
use crate::Instant;
|
||||||
|
|
||||||
|
/// Default interval in millis in which the write side is woken up when
|
||||||
|
/// reaching throughput limits. This sets the granularity of the rate limiting
|
||||||
|
/// and an upper bound on the throughput.
|
||||||
|
const WAKE_INTERVAL: u64 = 1;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct TokenBucket {
|
||||||
|
capacity: u64,
|
||||||
|
tokens: u64,
|
||||||
|
/// Refill rate in tokens per micro second.
|
||||||
|
rate: u64,
|
||||||
|
last_refill: Instant,
|
||||||
|
timer: Pin<Box<Delay>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TokenBucket {
|
||||||
|
/// Create a new `TokenBucket`.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `capacity` - Maximum number of tokens the bucket can hold.
|
||||||
|
/// * `rate` - Refill rate in tokens per microsecond.
|
||||||
|
pub(crate) fn new(capacity: u64, rate: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
capacity,
|
||||||
|
tokens: capacity,
|
||||||
|
rate,
|
||||||
|
last_refill: Instant::now(),
|
||||||
|
timer: Box::pin(Delay::new(Duration::from_millis(WAKE_INTERVAL))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn available(&self) -> u64 {
|
||||||
|
self.tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn consume(&mut self, amount: u64) {
|
||||||
|
self.tokens = self.tokens.saturating_sub(amount);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn poll_refill(&mut self, cx: &mut Context<'_>) {
|
||||||
|
self.timer.reset(Duration::from_millis(WAKE_INTERVAL));
|
||||||
|
assert!(self.timer.as_mut().poll(cx).is_pending());
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn refill(&mut self) {
|
||||||
|
let now = Instant::now();
|
||||||
|
let elapsed = now.duration_since(self.last_refill).as_micros() as u64;
|
||||||
|
if elapsed == 0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let tokens = elapsed.saturating_mul(self.rate);
|
||||||
|
self.tokens = self.tokens.saturating_add(tokens).min(self.capacity);
|
||||||
|
self.last_refill = now;
|
||||||
|
}
|
||||||
|
}
|
||||||
307
futures-limit/src/delay.rs
Normal file
307
futures-limit/src/delay.rs
Normal file
@@ -0,0 +1,307 @@
|
|||||||
|
use std::{
|
||||||
|
collections::VecDeque,
|
||||||
|
io::Result,
|
||||||
|
pin::Pin,
|
||||||
|
task::{Context, Poll, Waker, ready},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use bytes::{Buf, BytesMut};
|
||||||
|
use futures::{AsyncRead, AsyncWrite, lock::BiLock};
|
||||||
|
use futures_timer::Delay as DelayTimer;
|
||||||
|
use pin_project_lite::pin_project;
|
||||||
|
|
||||||
|
use crate::Instant;
|
||||||
|
|
||||||
|
const BUF_SIZE: usize = 16 * 1024; // 16 KiB
|
||||||
|
|
||||||
|
/// Delay wrapper for `AsyncRead`.
|
||||||
|
///
|
||||||
|
/// This wrapper will delay incoming data by the provided amount of
|
||||||
|
/// milliseconds. A corresponding future is also returned. This future should be
|
||||||
|
/// spawned onto a dedicated thread to ensure that the delay is accurate.
|
||||||
|
///
|
||||||
|
/// # Warning
|
||||||
|
///
|
||||||
|
/// Incoming data is continuously read from the underlying I/O object. This
|
||||||
|
/// buffer will continue to grow unbounded if the data is processed slower than
|
||||||
|
/// it is received.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Delay<Io> {
|
||||||
|
read: BiLock<Simplex>,
|
||||||
|
write: BiLock<Io>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Io> Delay<Io> {
|
||||||
|
/// Create a new delay.
|
||||||
|
///
|
||||||
|
/// Returns a future which must be polled continuously. This future should
|
||||||
|
/// be spawned onto a dedicated thread to ensure that the delay is accurate.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `io` - Underlying I/O object.
|
||||||
|
/// * `delay` - Delay in milliseconds.
|
||||||
|
pub fn new(io: Io, delay: usize) -> (Self, DelayFuture<Io>) {
|
||||||
|
let simplex = Simplex::new(delay);
|
||||||
|
|
||||||
|
let (delay_read, delay_write) = BiLock::new(simplex);
|
||||||
|
let (io_read, io_write) = BiLock::new(io);
|
||||||
|
|
||||||
|
(
|
||||||
|
Self {
|
||||||
|
read: delay_read,
|
||||||
|
write: io_write,
|
||||||
|
},
|
||||||
|
DelayFuture {
|
||||||
|
delay: delay as u64,
|
||||||
|
read: io_read,
|
||||||
|
buf: vec![0; BUF_SIZE].into_boxed_slice(),
|
||||||
|
write: delay_write,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Io> AsyncRead for Delay<Io>
|
||||||
|
where
|
||||||
|
Io: AsyncRead,
|
||||||
|
{
|
||||||
|
#[inline]
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut [u8],
|
||||||
|
) -> Poll<Result<usize>> {
|
||||||
|
let mut read = ready!(self.read.poll_lock(cx));
|
||||||
|
read.poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Io> AsyncWrite for Delay<Io>
|
||||||
|
where
|
||||||
|
Io: AsyncWrite,
|
||||||
|
{
|
||||||
|
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
|
||||||
|
let mut write = ready!(self.write.poll_lock(cx));
|
||||||
|
write.as_pin_mut().poll_write(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||||
|
let mut write = ready!(self.write.poll_lock(cx));
|
||||||
|
write.as_pin_mut().poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||||
|
let mut write = ready!(self.write.poll_lock(cx));
|
||||||
|
write.as_pin_mut().poll_close(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pin_project! {
|
||||||
|
/// Future returned by [`Delay::new`].
|
||||||
|
#[must_use = "futures do nothing unless you `.await` or poll them"]
|
||||||
|
pub struct DelayFuture<Io> {
|
||||||
|
delay: u64,
|
||||||
|
read: BiLock<Io>,
|
||||||
|
buf: Box<[u8]>,
|
||||||
|
write: BiLock<Simplex>,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Io> Future for DelayFuture<Io>
|
||||||
|
where
|
||||||
|
Io: AsyncRead,
|
||||||
|
{
|
||||||
|
type Output = Result<()>;
|
||||||
|
|
||||||
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
|
let this = self.project();
|
||||||
|
|
||||||
|
let mut write = ready!(this.write.poll_lock(cx));
|
||||||
|
let mut read = ready!(this.read.poll_lock(cx));
|
||||||
|
|
||||||
|
let mut len = 0;
|
||||||
|
let mut closed = false;
|
||||||
|
while let Poll::Ready(res) = read.as_pin_mut().poll_read(cx, this.buf) {
|
||||||
|
match res {
|
||||||
|
Ok(n) => {
|
||||||
|
if n == 0 {
|
||||||
|
closed = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
len += n;
|
||||||
|
write.buf.extend_from_slice(&this.buf[..n]);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
write.close_write();
|
||||||
|
return Poll::Ready(Err(err));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len > 0 {
|
||||||
|
write.packets.push_front(Packet {
|
||||||
|
len,
|
||||||
|
ready: Instant::now() + Duration::from_millis(*this.delay),
|
||||||
|
});
|
||||||
|
write.wake_reader();
|
||||||
|
}
|
||||||
|
|
||||||
|
if closed {
|
||||||
|
write.close_write();
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
} else {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
struct Packet {
|
||||||
|
len: usize,
|
||||||
|
/// Time when the packet is ready.
|
||||||
|
ready: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Simplex {
|
||||||
|
buf: BytesMut,
|
||||||
|
/// Packets in the buffer.
|
||||||
|
packets: VecDeque<Packet>,
|
||||||
|
/// Whether the write side has closed.
|
||||||
|
is_closed: bool,
|
||||||
|
/// Waker for the read side.
|
||||||
|
read_waker: Option<Waker>,
|
||||||
|
/// Timer to wake up the read side when the latency has elapsed.
|
||||||
|
read_timer: Pin<Box<DelayTimer>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Simplex {
|
||||||
|
fn new(delay: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
buf: BytesMut::with_capacity(16 * 1024),
|
||||||
|
packets: VecDeque::new(),
|
||||||
|
is_closed: false,
|
||||||
|
read_waker: None,
|
||||||
|
read_timer: Box::pin(DelayTimer::new(Duration::from_millis(delay as u64))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn close_write(&mut self) {
|
||||||
|
self.is_closed = true;
|
||||||
|
// needs to notify any readers that no more data will come
|
||||||
|
self.wake_reader();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn wake_reader(&mut self) {
|
||||||
|
if let Some(waker) = self.read_waker.take() {
|
||||||
|
waker.wake();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_read(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
|
||||||
|
if self.buf.has_remaining() {
|
||||||
|
// Maximum amount of bytes that can be processed this poll.
|
||||||
|
let max_len = self.buf.remaining().min(buf.len());
|
||||||
|
|
||||||
|
// Read packets in reverse order to process the oldest packets first.
|
||||||
|
//
|
||||||
|
// Steps:
|
||||||
|
// 1. Check that the packet is ready to be read (latency has elapsed). If not,
|
||||||
|
// register a waker for when it is ready.
|
||||||
|
// 2. Read as many bytes as possible from the packet. Update the packet length
|
||||||
|
// if it is partially read.
|
||||||
|
// 3. Remove fully read packets from the queue.
|
||||||
|
let mut remaining = max_len;
|
||||||
|
let mut done_packets = 0;
|
||||||
|
let now = Instant::now();
|
||||||
|
for Packet {
|
||||||
|
len: packet_len,
|
||||||
|
ready,
|
||||||
|
} in self.packets.iter_mut().rev()
|
||||||
|
{
|
||||||
|
let time_left = ready.saturating_duration_since(now);
|
||||||
|
if time_left.as_millis() > 0 {
|
||||||
|
self.read_timer.reset(time_left);
|
||||||
|
// Poll timer to register waker.
|
||||||
|
assert!(self.read_timer.as_mut().poll(cx).is_pending());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let len = (*packet_len).min(remaining);
|
||||||
|
if len == *packet_len {
|
||||||
|
done_packets += 1;
|
||||||
|
} else {
|
||||||
|
// Partial read, update packet length.
|
||||||
|
*packet_len -= len;
|
||||||
|
}
|
||||||
|
|
||||||
|
remaining -= len;
|
||||||
|
if remaining == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if remaining == max_len {
|
||||||
|
// No packets are ready to be read, so we need to wait for the timer to expire.
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove packets that have been fully read.
|
||||||
|
self.packets.truncate(self.packets.len() - done_packets);
|
||||||
|
|
||||||
|
let len = max_len - remaining;
|
||||||
|
buf[..len].copy_from_slice(&self.buf[..len]);
|
||||||
|
self.buf.advance(len);
|
||||||
|
|
||||||
|
Poll::Ready(Ok(len))
|
||||||
|
} else if self.is_closed {
|
||||||
|
Poll::Ready(Ok(0))
|
||||||
|
} else {
|
||||||
|
self.read_waker = Some(cx.waker().clone());
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::{pin::pin, time::Duration};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use futures::{AsyncWriteExt, future::poll_fn, poll};
|
||||||
|
use futures_plex::simplex;
|
||||||
|
use mock_instant::thread_local::MockClock;
|
||||||
|
|
||||||
|
#[pollster::test]
|
||||||
|
async fn test_delay() {
|
||||||
|
let data = b"hello world";
|
||||||
|
const DELAY: usize = 1;
|
||||||
|
|
||||||
|
let (read, mut write) = simplex(100);
|
||||||
|
let (mut delay, mut fut) = Delay::new(read, DELAY);
|
||||||
|
|
||||||
|
write.write_all(data).await.unwrap();
|
||||||
|
write.flush().await.unwrap();
|
||||||
|
|
||||||
|
assert!(poll!(&mut fut).is_pending());
|
||||||
|
|
||||||
|
let mut buf = vec![0u8; 11];
|
||||||
|
let res = poll!(poll_fn(|cx| pin!(&mut delay).poll_read(cx, &mut buf)));
|
||||||
|
|
||||||
|
// Data should not be available yet.
|
||||||
|
assert!(res.is_pending());
|
||||||
|
|
||||||
|
MockClock::advance(Duration::from_millis(DELAY as u64));
|
||||||
|
|
||||||
|
let res = poll!(poll_fn(|cx| pin!(&mut delay).poll_read(cx, &mut buf)));
|
||||||
|
|
||||||
|
// Data should be available now.
|
||||||
|
assert!(matches!(res, Poll::Ready(Ok(11))));
|
||||||
|
|
||||||
|
write.close().await.unwrap();
|
||||||
|
|
||||||
|
assert!(poll!(&mut fut).is_ready());
|
||||||
|
}
|
||||||
|
}
|
||||||
54
futures-limit/src/lib.rs
Normal file
54
futures-limit/src/lib.rs
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
#![doc = include_str!("../README.md")]
|
||||||
|
|
||||||
|
pub(crate) mod bucket;
|
||||||
|
mod delay;
|
||||||
|
mod rate;
|
||||||
|
|
||||||
|
pub use delay::{Delay, DelayFuture};
|
||||||
|
pub use rate::Rate;
|
||||||
|
|
||||||
|
use futures::{AsyncRead, AsyncWrite};
|
||||||
|
#[cfg(test)]
|
||||||
|
pub(crate) use mock_instant::thread_local::Instant;
|
||||||
|
#[cfg(all(not(test), not(target_arch = "wasm32")))]
|
||||||
|
pub(crate) use std::time::Instant;
|
||||||
|
#[cfg(all(not(test), target_arch = "wasm32"))]
|
||||||
|
pub(crate) use web_time::Instant;
|
||||||
|
|
||||||
|
/// Extension trait for `AsyncWrite`.
|
||||||
|
pub trait AsyncWriteLimitExt: AsyncWrite {
|
||||||
|
/// Limit the write rate of the underlying writer.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `burst` - Maximum burst size in bits.
|
||||||
|
/// * `rate` - Maximum write rate in bits per second.
|
||||||
|
fn limit_rate(self, burst: usize, rate: usize) -> Rate<Self>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
Rate::new(self, burst, rate)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> AsyncWriteLimitExt for T where T: AsyncWrite {}
|
||||||
|
|
||||||
|
/// Extension trait for `AsyncRead`.
|
||||||
|
pub trait AsyncReadDelayExt: AsyncRead {
|
||||||
|
/// Delays incoming data by the given amount of milliseconds.
|
||||||
|
///
|
||||||
|
/// Returns a future which must be polled continuously. See [`Delay`] for
|
||||||
|
/// more details.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `delay` - Delay in milliseconds.
|
||||||
|
fn delay(self, delay: usize) -> (Delay<Self>, DelayFuture<Self>)
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
Delay::new(self, delay)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> AsyncReadDelayExt for T where T: AsyncRead {}
|
||||||
176
futures-limit/src/rate.rs
Normal file
176
futures-limit/src/rate.rs
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
use std::{
|
||||||
|
io::{IoSliceMut, Result},
|
||||||
|
pin::Pin,
|
||||||
|
task::{Context, Poll},
|
||||||
|
};
|
||||||
|
|
||||||
|
use futures::{AsyncRead, AsyncWrite};
|
||||||
|
use pin_project_lite::pin_project;
|
||||||
|
|
||||||
|
use crate::bucket::TokenBucket;
|
||||||
|
|
||||||
|
const M: u64 = 1_000_000;
|
||||||
|
|
||||||
|
pin_project! {
|
||||||
|
/// Rate limiting wrapper for `AsyncWrite`.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Rate<Io> {
|
||||||
|
#[pin] io: Io,
|
||||||
|
bucket: TokenBucket,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Io> Rate<Io> {
|
||||||
|
/// Create a new rate limiter.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `io` - Underlying I/O object.
|
||||||
|
/// * `burst` - Maximum burst size in bits.
|
||||||
|
/// * `rate` - Maximum write rate in bits per second.
|
||||||
|
pub fn new(io: Io, burst: usize, rate: usize) -> Self {
|
||||||
|
// Bucketing is done with microsecond granularity.
|
||||||
|
// Each token represents one-millionth of a byte.
|
||||||
|
let tokens = ((burst as u64) * M).div_ceil(8);
|
||||||
|
let tokens_per_micro_sec = (rate as u64).div_ceil(8);
|
||||||
|
|
||||||
|
let bucket = TokenBucket::new(tokens, tokens_per_micro_sec);
|
||||||
|
|
||||||
|
Self { io, bucket }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Io> Rate<Io>
|
||||||
|
where
|
||||||
|
Io: AsyncWrite,
|
||||||
|
{
|
||||||
|
fn poll_write_internal(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<Result<usize>> {
|
||||||
|
let this = self.project();
|
||||||
|
|
||||||
|
this.bucket.refill();
|
||||||
|
|
||||||
|
let available = (this.bucket.available() / M) as usize;
|
||||||
|
if available == 0 {
|
||||||
|
this.bucket.poll_refill(cx);
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
|
||||||
|
let len = buf.len().min(available);
|
||||||
|
|
||||||
|
let res = this.io.poll_write(cx, &buf[..len]);
|
||||||
|
|
||||||
|
if let Poll::Ready(Ok(n)) = &res {
|
||||||
|
this.bucket.consume((*n as u64) * M);
|
||||||
|
}
|
||||||
|
|
||||||
|
res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Io> AsyncWrite for Rate<Io>
|
||||||
|
where
|
||||||
|
Io: AsyncWrite,
|
||||||
|
{
|
||||||
|
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
|
||||||
|
self.poll_write_internal(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||||
|
self.project().io.poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||||
|
self.project().io.poll_close(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Io> AsyncRead for Rate<Io>
|
||||||
|
where
|
||||||
|
Io: AsyncRead,
|
||||||
|
{
|
||||||
|
#[inline]
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut [u8],
|
||||||
|
) -> Poll<Result<usize>> {
|
||||||
|
self.project().io.poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll_read_vectored(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
bufs: &mut [IoSliceMut<'_>],
|
||||||
|
) -> Poll<Result<usize>> {
|
||||||
|
self.project().io.poll_read_vectored(cx, bufs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use futures::{AsyncWriteExt, io::sink, poll};
|
||||||
|
use mock_instant::thread_local::MockClock;
|
||||||
|
|
||||||
|
// Tests that the burst size is respected.
|
||||||
|
#[pollster::test]
|
||||||
|
async fn test_rate_burst() {
|
||||||
|
let data = b"hello world";
|
||||||
|
|
||||||
|
let mut io = Rate::new(sink(), (data.len() - 1) * 8, 0);
|
||||||
|
|
||||||
|
let n = io.write(data).await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(n, data.len() - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that the burst will allow all data to be written when it is less than
|
||||||
|
// the burst size.
|
||||||
|
#[pollster::test]
|
||||||
|
async fn test_rate_burst_all() {
|
||||||
|
let data = b"hello world";
|
||||||
|
|
||||||
|
let mut io = Rate::new(sink(), data.len() * 8, 0);
|
||||||
|
|
||||||
|
let n = io.write(data).await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(n, data.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pollster::test]
|
||||||
|
async fn test_rate_limit() {
|
||||||
|
let data = b"hello world";
|
||||||
|
|
||||||
|
let mut io = Rate::new(sink(), data.len() * 8, 8);
|
||||||
|
|
||||||
|
let n = io.write(data).await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(n, data.len());
|
||||||
|
|
||||||
|
let mut write = io.write(data);
|
||||||
|
|
||||||
|
assert!(poll!(&mut write).is_pending());
|
||||||
|
|
||||||
|
MockClock::advance(Duration::from_secs(1));
|
||||||
|
|
||||||
|
let Poll::Ready(Ok(n)) = poll!(write) else {
|
||||||
|
panic!("poll should be ready");
|
||||||
|
};
|
||||||
|
|
||||||
|
// 1 byte per second.
|
||||||
|
assert_eq!(n, 1);
|
||||||
|
|
||||||
|
let mut write = io.write(data);
|
||||||
|
|
||||||
|
assert!(poll!(&mut write).is_pending());
|
||||||
|
}
|
||||||
|
}
|
||||||
334
futures-limit/src/simplex.rs
Normal file
334
futures-limit/src/simplex.rs
Normal file
@@ -0,0 +1,334 @@
|
|||||||
|
use std::{
|
||||||
|
collections::VecDeque,
|
||||||
|
future::Future,
|
||||||
|
io::Result,
|
||||||
|
pin::Pin,
|
||||||
|
task::{Context, Poll, Waker},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use bytes::{Buf, BytesMut};
|
||||||
|
use futures::{
|
||||||
|
io::{ReadHalf, WriteHalf},
|
||||||
|
AsyncRead, AsyncReadExt, AsyncWrite,
|
||||||
|
};
|
||||||
|
use futures_timer::Delay;
|
||||||
|
use pin_project_lite::pin_project;
|
||||||
|
|
||||||
|
use crate::Instant;
|
||||||
|
|
||||||
|
/// Returns a simplex connection pair.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `params` - Parameters for the connection.
|
||||||
|
pub fn simplex(params: Params) -> (ReadHalf<Simplex>, WriteHalf<Simplex>) {
|
||||||
|
Simplex::new(params).split()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
struct Packet {
|
||||||
|
len: usize,
|
||||||
|
/// Time when the packet is ready.
|
||||||
|
ready: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unidirectional pipe with configurable bandwidth, latency and buffer size.
|
||||||
|
///
|
||||||
|
/// Implementation is based on the simplex in `tokio`.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Simplex {
|
||||||
|
params: Params,
|
||||||
|
buf: BytesMut,
|
||||||
|
/// Packets in the buffer.
|
||||||
|
packets: VecDeque<Packet>,
|
||||||
|
/// Whether the write side has closed.
|
||||||
|
is_closed: bool,
|
||||||
|
/// Waker for the read side.
|
||||||
|
read_waker: Option<Waker>,
|
||||||
|
/// Read bucket.
|
||||||
|
read_bucket: TokenBucket,
|
||||||
|
/// Timer to wake up the read side when the latency has elapsed.
|
||||||
|
read_timer: Pin<Box<Delay>>,
|
||||||
|
/// Waker for the write side.
|
||||||
|
write_waker: Option<Waker>,
|
||||||
|
/// Write bucket.
|
||||||
|
write_bucket: TokenBucket,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Simplex {
|
||||||
|
/// Create a new `Simplex`.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if `tx_rate` or `rx_rate` are less than 8 bits per second.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `params` - Parameters for the connection.
|
||||||
|
pub fn new(params: Params) -> Self {
|
||||||
|
assert!(
|
||||||
|
params.tx_rate >= 8,
|
||||||
|
"tx_rate must be at least 8 bits per second"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
params.rx_rate >= 8,
|
||||||
|
"rx_rate must be at least 8 bits per second"
|
||||||
|
);
|
||||||
|
|
||||||
|
let tx_bytes_per_sec = params.tx_rate >> 3;
|
||||||
|
let rx_bytes_per_sec = params.rx_rate >> 3;
|
||||||
|
|
||||||
|
let write_bucket = TokenBucket::new(DEFAULT_BUCKET_CAPACITY, tx_bytes_per_sec >> 20);
|
||||||
|
let read_bucket = TokenBucket::new(DEFAULT_BUCKET_CAPACITY, rx_bytes_per_sec >> 20);
|
||||||
|
|
||||||
|
Self {
|
||||||
|
params,
|
||||||
|
buf: BytesMut::new(),
|
||||||
|
packets: VecDeque::new(),
|
||||||
|
is_closed: false,
|
||||||
|
read_waker: None,
|
||||||
|
read_bucket,
|
||||||
|
read_timer: Box::pin(Delay::new(Duration::from_millis(0))),
|
||||||
|
write_waker: None,
|
||||||
|
write_bucket,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn close_write(&mut self) {
|
||||||
|
self.is_closed = true;
|
||||||
|
// needs to notify any readers that no more data will come
|
||||||
|
if let Some(waker) = self.read_waker.take() {
|
||||||
|
waker.wake();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_internal(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
|
||||||
|
if self.is_closed {
|
||||||
|
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let len = self.params.buf_size - self.buf.len();
|
||||||
|
if len == 0 {
|
||||||
|
// Buffer is full, so we need to wait for some data to be read.
|
||||||
|
self.write_waker = Some(cx.waker().clone());
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.write_bucket.refill();
|
||||||
|
let len = len.min(self.write_bucket.available());
|
||||||
|
if len == 0 {
|
||||||
|
// No tokens available, so we need to wait for the bucket to refill.
|
||||||
|
assert!(self.write_bucket.poll_refill(cx).is_pending());
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
|
||||||
|
let len = len.min(buf.len());
|
||||||
|
self.buf.extend_from_slice(&buf[..len]);
|
||||||
|
self.write_bucket.consume(len);
|
||||||
|
self.packets.push_front(Packet {
|
||||||
|
len,
|
||||||
|
ready: Instant::now() + Duration::from_millis(self.params.latency as u64),
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(waker) = self.read_waker.take() {
|
||||||
|
waker.wake();
|
||||||
|
}
|
||||||
|
|
||||||
|
Poll::Ready(Ok(len))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_read_internal(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
|
||||||
|
if self.buf.has_remaining() {
|
||||||
|
self.read_bucket.refill();
|
||||||
|
if self.read_bucket.is_empty() {
|
||||||
|
// No tokens available, so we need to wait for the bucket to refill.
|
||||||
|
assert!(self.read_bucket.poll_refill(cx).is_pending());
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Maximum amount of bytes that can be processed this poll.
|
||||||
|
let max_len = self
|
||||||
|
.buf
|
||||||
|
.remaining()
|
||||||
|
.min(buf.len())
|
||||||
|
.min(self.read_bucket.available());
|
||||||
|
|
||||||
|
// Read packets in reverse order to process the oldest packets first.
|
||||||
|
//
|
||||||
|
// Steps:
|
||||||
|
// 1. Check that the packet is ready to be read (latency has elapsed). If not,
|
||||||
|
// register a waker for when it is ready.
|
||||||
|
// 2. Read as many bytes as possible from the packet. Update the packet length
|
||||||
|
// if it is partially read.
|
||||||
|
// 3. Remove fully read packets from the queue.
|
||||||
|
let mut remaining = max_len;
|
||||||
|
let mut complete = 0;
|
||||||
|
let now = Instant::now();
|
||||||
|
for Packet {
|
||||||
|
len: packet_len,
|
||||||
|
ready,
|
||||||
|
} in self.packets.iter_mut().rev()
|
||||||
|
{
|
||||||
|
let time_left = ready.saturating_duration_since(now);
|
||||||
|
if time_left.as_millis() > 0 {
|
||||||
|
self.read_timer.reset(time_left);
|
||||||
|
// Poll timer to register waker.
|
||||||
|
assert!(self.read_timer.as_mut().poll(cx).is_pending());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let len = (*packet_len).min(remaining);
|
||||||
|
if len == *packet_len {
|
||||||
|
complete += 1;
|
||||||
|
} else {
|
||||||
|
// Partial read, update packet length.
|
||||||
|
*packet_len -= len;
|
||||||
|
}
|
||||||
|
|
||||||
|
remaining -= len;
|
||||||
|
if remaining == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if remaining == max_len {
|
||||||
|
// No packets are ready to be read, so we need to wait for the timer to expire.
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove packets that have been fully read.
|
||||||
|
self.packets.truncate(self.packets.len() - complete);
|
||||||
|
|
||||||
|
let len = max_len - remaining;
|
||||||
|
buf[..len].copy_from_slice(&self.buf[..len]);
|
||||||
|
self.buf.advance(len);
|
||||||
|
self.read_bucket.consume(len);
|
||||||
|
if len > 0 {
|
||||||
|
// The passed `buf` might have been empty, don't wake up if
|
||||||
|
// no bytes have been moved.
|
||||||
|
if let Some(waker) = self.write_waker.take() {
|
||||||
|
waker.wake();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Poll::Ready(Ok(len))
|
||||||
|
} else if self.is_closed {
|
||||||
|
Poll::Ready(Ok(0))
|
||||||
|
} else {
|
||||||
|
self.read_waker = Some(cx.waker().clone());
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for Simplex {
|
||||||
|
fn poll_write(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<Result<usize>> {
|
||||||
|
self.poll_write_internal(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||||
|
self.close_write();
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for Simplex {
|
||||||
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut [u8],
|
||||||
|
) -> Poll<Result<usize>> {
|
||||||
|
self.poll_read_internal(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use futures::{poll, AsyncReadExt, AsyncWriteExt};
|
||||||
|
use mock_instant::thread_local::MockClock;
|
||||||
|
use pollster::FutureExt;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_simplex() {
|
||||||
|
async {
|
||||||
|
let mut io = Simplex::new(Params {
|
||||||
|
buf_size: 1024,
|
||||||
|
tx_rate: 1024 * 8,
|
||||||
|
rx_rate: 1024 * 8,
|
||||||
|
latency: 0,
|
||||||
|
});
|
||||||
|
|
||||||
|
let data = b"hello world";
|
||||||
|
io.write_all(data).await.unwrap();
|
||||||
|
|
||||||
|
let mut buf = [0; 1024];
|
||||||
|
let len = io.read(&mut buf).await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(len, data.len());
|
||||||
|
assert_eq!(&buf[..len], data);
|
||||||
|
}
|
||||||
|
.block_on()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_simplex_burst_write() {
|
||||||
|
async {
|
||||||
|
let mut io = Simplex::new(Params {
|
||||||
|
buf_size: DEFAULT_BUCKET_CAPACITY,
|
||||||
|
tx_rate: 8,
|
||||||
|
rx_rate: 1024 * 8,
|
||||||
|
latency: 0,
|
||||||
|
});
|
||||||
|
|
||||||
|
let data = vec![0; DEFAULT_BUCKET_CAPACITY];
|
||||||
|
|
||||||
|
// Burst write should accept the full buffer.
|
||||||
|
assert_eq!(
|
||||||
|
poll!(io.write(&data)).map(|r| r.unwrap()),
|
||||||
|
Poll::Ready(data.len())
|
||||||
|
);
|
||||||
|
assert_eq!(io.write_bucket.tokens, 0);
|
||||||
|
}
|
||||||
|
.block_on()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_simplex_latency() {
|
||||||
|
async {
|
||||||
|
let mut io = Simplex::new(Params {
|
||||||
|
buf_size: DEFAULT_BUCKET_CAPACITY,
|
||||||
|
tx_rate: 8,
|
||||||
|
rx_rate: 1024 * 8,
|
||||||
|
latency: 2,
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut data = vec![0; DEFAULT_BUCKET_CAPACITY];
|
||||||
|
io.write_all(&data).await.unwrap();
|
||||||
|
|
||||||
|
// No time has elapsed, so no data should be available.
|
||||||
|
assert_eq!(poll!(io.read(&mut data)).map(|r| r.unwrap()), Poll::Pending);
|
||||||
|
|
||||||
|
// Latency still hasn't elapsed.
|
||||||
|
MockClock::advance(Duration::from_millis(1));
|
||||||
|
assert_eq!(poll!(io.read(&mut data)).map(|r| r.unwrap()), Poll::Pending);
|
||||||
|
|
||||||
|
// Latency has elapsed, so data should be available.
|
||||||
|
MockClock::advance(Duration::from_millis(1));
|
||||||
|
assert_eq!(
|
||||||
|
poll!(io.read(&mut data)).map(|r| r.unwrap()),
|
||||||
|
Poll::Ready(DEFAULT_BUCKET_CAPACITY)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
.block_on()
|
||||||
|
}
|
||||||
|
}
|
||||||
10
futures-plex/Cargo.toml
Normal file
10
futures-plex/Cargo.toml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
[package]
|
||||||
|
name = "futures-plex"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
description = "Port of tokio's `SimplexStream` and `DuplexStream` for the `futures` ecosystem."
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
bytes = { version = "1" }
|
||||||
|
futures-io = { version = "0.3" }
|
||||||
|
futures-util = { version = "0.3", default-features = false, features = ["io"] }
|
||||||
5
futures-plex/README.md
Normal file
5
futures-plex/README.md
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# futures-plex
|
||||||
|
|
||||||
|
Port of tokio's `SimplexStream` and `DuplexStream` for the `futures` ecosystem.
|
||||||
|
|
||||||
|
This crate provides in-memory implementations for `AsyncRead` and `AsyncWrite`.
|
||||||
344
futures-plex/src/lib.rs
Normal file
344
futures-plex/src/lib.rs
Normal file
@@ -0,0 +1,344 @@
|
|||||||
|
#![doc = include_str!("../README.md")]
|
||||||
|
|
||||||
|
use std::{
|
||||||
|
pin::Pin,
|
||||||
|
task::{self, Poll, Waker},
|
||||||
|
};
|
||||||
|
|
||||||
|
use bytes::{Buf, BytesMut};
|
||||||
|
use futures_io::{AsyncRead, AsyncWrite};
|
||||||
|
use futures_util::{
|
||||||
|
AsyncReadExt,
|
||||||
|
io::{ReadHalf, WriteHalf},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// A bidirectional pipe to read and write bytes in memory.
|
||||||
|
///
|
||||||
|
/// A pair of `DuplexStream`s are created together, and they act as a "channel"
|
||||||
|
/// that can be used as in-memory IO types. Writing to one of the pairs will
|
||||||
|
/// allow that data to be read from the other, and vice versa.
|
||||||
|
///
|
||||||
|
/// # Closing a `DuplexStream`
|
||||||
|
///
|
||||||
|
/// If one end of the `DuplexStream` channel is dropped, any pending reads on
|
||||||
|
/// the other side will continue to read data until the buffer is drained, then
|
||||||
|
/// they will signal EOF by returning 0 bytes. Any writes to the other side,
|
||||||
|
/// including pending ones (that are waiting for free space in the buffer) will
|
||||||
|
/// return `Err(BrokenPipe)` immediately.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # async fn ex() -> std::io::Result<()> {
|
||||||
|
/// # use futures_util::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
/// let (mut client, mut server) = futures_plex::duplex(64);
|
||||||
|
///
|
||||||
|
/// client.write_all(b"ping").await?;
|
||||||
|
///
|
||||||
|
/// let mut buf = [0u8; 4];
|
||||||
|
/// server.read_exact(&mut buf).await?;
|
||||||
|
/// assert_eq!(&buf, b"ping");
|
||||||
|
///
|
||||||
|
/// server.write_all(b"pong").await?;
|
||||||
|
///
|
||||||
|
/// client.read_exact(&mut buf).await?;
|
||||||
|
/// assert_eq!(&buf, b"pong");
|
||||||
|
/// # Ok(())
|
||||||
|
/// # }
|
||||||
|
/// ```
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct DuplexStream {
|
||||||
|
read: ReadHalf<SimplexStream>,
|
||||||
|
write: WriteHalf<SimplexStream>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A unidirectional pipe to read and write bytes in memory.
|
||||||
|
///
|
||||||
|
/// It can be constructed by [`simplex`] function which will create a pair of
|
||||||
|
/// reader and writer or by calling [`SimplexStream::new_unsplit`] that will
|
||||||
|
/// create a handle for both reading and writing.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # async fn ex() -> std::io::Result<()> {
|
||||||
|
/// # use futures_util::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
/// let (mut receiver, mut sender) = futures_plex::simplex(64);
|
||||||
|
///
|
||||||
|
/// sender.write_all(b"ping").await?;
|
||||||
|
///
|
||||||
|
/// let mut buf = [0u8; 4];
|
||||||
|
/// receiver.read_exact(&mut buf).await?;
|
||||||
|
/// assert_eq!(&buf, b"ping");
|
||||||
|
/// # Ok(())
|
||||||
|
/// # }
|
||||||
|
/// ```
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct SimplexStream {
|
||||||
|
/// The buffer storing the bytes written, also read from.
|
||||||
|
///
|
||||||
|
/// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
|
||||||
|
/// functionality already. Additionally, it can try to copy data in the
|
||||||
|
/// same buffer if there read index has advanced far enough.
|
||||||
|
buffer: BytesMut,
|
||||||
|
/// Determines if the write side has been closed.
|
||||||
|
is_closed: bool,
|
||||||
|
/// The maximum amount of bytes that can be written before returning
|
||||||
|
/// `Poll::Pending`.
|
||||||
|
max_buf_size: usize,
|
||||||
|
/// If the `read` side has been polled and is pending, this is the waker
|
||||||
|
/// for that parked task.
|
||||||
|
read_waker: Option<Waker>,
|
||||||
|
/// If the `write` side has filled the `max_buf_size` and returned
|
||||||
|
/// `Poll::Pending`, this is the waker for that parked task.
|
||||||
|
write_waker: Option<Waker>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== impl DuplexStream =====
|
||||||
|
|
||||||
|
/// Create a new pair of `DuplexStream`s that act like a pair of connected
|
||||||
|
/// sockets.
|
||||||
|
///
|
||||||
|
/// The `max_buf_size` argument is the maximum amount of bytes that can be
|
||||||
|
/// written to a side before the write returns `Poll::Pending`.
|
||||||
|
pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
|
||||||
|
let (read_0, write_0) = SimplexStream::new_unsplit(max_buf_size).split();
|
||||||
|
let (read_1, write_1) = SimplexStream::new_unsplit(max_buf_size).split();
|
||||||
|
|
||||||
|
(
|
||||||
|
DuplexStream {
|
||||||
|
read: read_0,
|
||||||
|
write: write_1,
|
||||||
|
},
|
||||||
|
DuplexStream {
|
||||||
|
read: read_1,
|
||||||
|
write: write_0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for DuplexStream {
|
||||||
|
// Previous rustc required this `self` to be `mut`, even though newer
|
||||||
|
// versions recognize it isn't needed to call `lock()`. So for
|
||||||
|
// compatibility, we include the `mut` and `allow` the lint.
|
||||||
|
//
|
||||||
|
// See https://github.com/rust-lang/rust/issues/73592
|
||||||
|
#[allow(unused_mut)]
|
||||||
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
buf: &mut [u8],
|
||||||
|
) -> Poll<std::io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.read).poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for DuplexStream {
|
||||||
|
#[allow(unused_mut)]
|
||||||
|
fn poll_write(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<std::io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.write).poll_write(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_vectored(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
bufs: &[std::io::IoSlice<'_>],
|
||||||
|
) -> Poll<Result<usize, std::io::Error>> {
|
||||||
|
Pin::new(&mut self.write).poll_write_vectored(cx, bufs)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(unused_mut)]
|
||||||
|
fn poll_flush(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
) -> Poll<std::io::Result<()>> {
|
||||||
|
Pin::new(&mut self.write).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(unused_mut)]
|
||||||
|
fn poll_close(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
) -> Poll<std::io::Result<()>> {
|
||||||
|
Pin::new(&mut self.write).poll_close(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== impl SimplexStream =====
|
||||||
|
|
||||||
|
/// Creates unidirectional buffer that acts like in memory pipe.
|
||||||
|
///
|
||||||
|
/// The `max_buf_size` argument is the maximum amount of bytes that can be
|
||||||
|
/// written to a buffer before the it returns `Poll::Pending`.
|
||||||
|
///
|
||||||
|
/// # Reunite reader and writer
|
||||||
|
///
|
||||||
|
/// The reader and writer half can be unified into a single structure
|
||||||
|
/// of `SimplexStream` that supports both reading and writing or
|
||||||
|
/// the `SimplexStream` can be already created as unified structure
|
||||||
|
/// using [`SimplexStream::new_unsplit()`].
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # async fn ex() -> std::io::Result<()> {
|
||||||
|
/// # use futures_util::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
/// let (reader, writer) = futures_plex::simplex(64);
|
||||||
|
/// let mut simplex_stream = reader.reunite(writer).unwrap();
|
||||||
|
/// simplex_stream.write_all(b"hello").await?;
|
||||||
|
///
|
||||||
|
/// let mut buf = [0u8; 5];
|
||||||
|
/// simplex_stream.read_exact(&mut buf).await?;
|
||||||
|
/// assert_eq!(&buf, b"hello");
|
||||||
|
/// # Ok(())
|
||||||
|
/// # }
|
||||||
|
/// ```
|
||||||
|
pub fn simplex(max_buf_size: usize) -> (ReadHalf<SimplexStream>, WriteHalf<SimplexStream>) {
|
||||||
|
SimplexStream::new_unsplit(max_buf_size).split()
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SimplexStream {
|
||||||
|
/// Creates unidirectional buffer that acts like in memory pipe. To create
|
||||||
|
/// split version with separate reader and writer you can use
|
||||||
|
/// [`simplex`] function.
|
||||||
|
///
|
||||||
|
/// The `max_buf_size` argument is the maximum amount of bytes that can be
|
||||||
|
/// written to a buffer before the it returns `Poll::Pending`.
|
||||||
|
pub fn new_unsplit(max_buf_size: usize) -> SimplexStream {
|
||||||
|
SimplexStream {
|
||||||
|
buffer: BytesMut::new(),
|
||||||
|
is_closed: false,
|
||||||
|
max_buf_size,
|
||||||
|
read_waker: None,
|
||||||
|
write_waker: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn close_write(&mut self) {
|
||||||
|
self.is_closed = true;
|
||||||
|
// needs to notify any readers that no more data will come
|
||||||
|
if let Some(waker) = self.read_waker.take() {
|
||||||
|
waker.wake();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_read_internal(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
buf: &mut [u8],
|
||||||
|
) -> Poll<std::io::Result<usize>> {
|
||||||
|
if self.buffer.has_remaining() {
|
||||||
|
let len = self.buffer.remaining().min(buf.len());
|
||||||
|
buf[..len].copy_from_slice(&self.buffer[..len]);
|
||||||
|
self.buffer.advance(len);
|
||||||
|
if len > 0 {
|
||||||
|
// The passed `buf` might have been empty, don't wake up if
|
||||||
|
// no bytes have been moved.
|
||||||
|
if let Some(waker) = self.write_waker.take() {
|
||||||
|
waker.wake();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Poll::Ready(Ok(len))
|
||||||
|
} else if self.is_closed {
|
||||||
|
Poll::Ready(Ok(0))
|
||||||
|
} else {
|
||||||
|
self.read_waker = Some(cx.waker().clone());
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_internal(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<std::io::Result<usize>> {
|
||||||
|
if self.is_closed {
|
||||||
|
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
|
||||||
|
}
|
||||||
|
let avail = self.max_buf_size - self.buffer.len();
|
||||||
|
if avail == 0 {
|
||||||
|
self.write_waker = Some(cx.waker().clone());
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
|
||||||
|
let len = buf.len().min(avail);
|
||||||
|
self.buffer.extend_from_slice(&buf[..len]);
|
||||||
|
if let Some(waker) = self.read_waker.take() {
|
||||||
|
waker.wake();
|
||||||
|
}
|
||||||
|
Poll::Ready(Ok(len))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_vectored_internal(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
bufs: &[std::io::IoSlice<'_>],
|
||||||
|
) -> Poll<Result<usize, std::io::Error>> {
|
||||||
|
if self.is_closed {
|
||||||
|
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
|
||||||
|
}
|
||||||
|
let avail = self.max_buf_size - self.buffer.len();
|
||||||
|
if avail == 0 {
|
||||||
|
self.write_waker = Some(cx.waker().clone());
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut rem = avail;
|
||||||
|
for buf in bufs {
|
||||||
|
if rem == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let len = buf.len().min(rem);
|
||||||
|
self.buffer.extend_from_slice(&buf[..len]);
|
||||||
|
rem -= len;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(waker) = self.read_waker.take() {
|
||||||
|
waker.wake();
|
||||||
|
}
|
||||||
|
Poll::Ready(Ok(avail - rem))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for SimplexStream {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
buf: &mut [u8],
|
||||||
|
) -> Poll<std::io::Result<usize>> {
|
||||||
|
self.poll_read_internal(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for SimplexStream {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<std::io::Result<usize>> {
|
||||||
|
self.poll_write_internal(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_vectored(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
bufs: &[std::io::IoSlice<'_>],
|
||||||
|
) -> Poll<std::io::Result<usize>> {
|
||||||
|
self.poll_write_vectored_internal(cx, bufs)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
_: &mut task::Context<'_>,
|
||||||
|
) -> Poll<std::io::Result<()>> {
|
||||||
|
self.close_write();
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,2 +1,2 @@
|
|||||||
imports_granularity = "Crate"
|
imports_granularity = "Crate"
|
||||||
wrap_comments = true
|
wrap_comments = true
|
||||||
|
|||||||
Reference in New Issue
Block a user