feat: futures-limit and futures-plex (#51)

* feat: futures-limit and futures-plex

* use web time for wasm
sinu.eth
2025-03-03 10:10:54 -08:00
committed by GitHub
parent 87cba96727
commit 9512cc1c44
13 changed files with 1382 additions and 1 deletion


@@ -7,10 +7,13 @@ members = [
"utils-aio",
"utils/fuzz",
"websocket-relay",
"futures-limit",
"futures-plex",
"web-spawn",
]
[workspace.dependencies]
futures-plex = { path = "futures-plex" }
serio = { path = "serio" }
spansy = { path = "spansy" }
tlsn-utils = { path = "utils" }
@@ -23,6 +26,7 @@ async-tungstenite = "0.16"
bincode = "1.3"
bytes = "1"
cfg-if = "1"
criterion = "0.5"
futures = "0.3"
futures-channel = "0.3"
futures-core = "0.3"
@@ -30,6 +34,7 @@ futures-io = "0.3"
futures-sink = "0.3"
futures-util = "0.3"
pin-project-lite = "0.2"
pollster = "0.4"
prost = "0.9"
prost-build = "0.9"
rand = "0.8"

futures-limit/Cargo.toml

@@ -0,0 +1,24 @@
[package]
name = "futures-limit"
version = "0.1.0"
edition = "2024"
[dependencies]
bytes = { workspace = true }
futures = { workspace = true, features = ["bilock", "unstable"] }
futures-timer = { version = "3" }
pin-project-lite = { workspace = true }
[target.'cfg(target_arch = "wasm32")'.dependencies]
futures-timer = { version = "3", features = ["wasm-bindgen"] }
web-time = { version = "1.1" }
[dev-dependencies]
criterion = { workspace = true }
pollster = { workspace = true, features = ["macro"] }
mock_instant = "0.5"
futures-plex = { workspace = true }
[[bench]]
name = "bench"
harness = false

futures-limit/README.md

@@ -0,0 +1,3 @@
# futures-limit
This crate provides a rate-limiting wrapper for `AsyncWrite` and a delay wrapper for `AsyncRead`.
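A minimal usage sketch of both wrappers (the helper `wrap` and its bounds are illustrative, not part of the crate; `limit_rate` takes a burst in bits and a rate in bits per second, `delay` takes milliseconds):

```rust
use futures::{AsyncRead, AsyncWrite};
use futures_limit::{AsyncReadDelayExt, AsyncWriteLimitExt, Delay, DelayFuture, Rate};

// Illustrative helper: wrap a writer with a rate limit and a reader with
// added latency.
fn wrap<W: AsyncWrite, R: AsyncRead>(w: W, r: R) -> (Rate<W>, Delay<R>, DelayFuture<R>) {
    // 64 Kbit burst, ~1 Mbit/s sustained write rate.
    let w = w.limit_rate(64 * 1024, 1_000_000);
    // 20 ms of added latency; the returned future must be polled
    // continuously, ideally on a dedicated thread.
    let (r, fut) = r.delay(20);
    (w, r, fut)
}
```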


@@ -0,0 +1,56 @@
use criterion::{BenchmarkId, Criterion, Throughput, black_box, criterion_group, criterion_main};
use futures::{AsyncReadExt, AsyncWriteExt};
use futures_limit::AsyncWriteLimitExt;
use futures_plex::simplex;
use pollster::FutureExt as _;
const M: usize = 1 << 20;
pub fn criterion_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("rate");
group.throughput(Throughput::Bytes(M as u64));
group.bench_function("max", |b| {
let (mut rx, tx) = simplex(M);
let mut tx = tx.limit_rate(8 * M, usize::MAX);
let tx_buf = vec![0; M];
let mut rx_buf = vec![0; M];
b.iter(|| {
async {
futures::try_join!(tx.write_all(&tx_buf), rx.read_exact(&mut rx_buf)).unwrap();
}
.block_on();
black_box(&rx_buf);
});
});
for mega_bits_per_sec in [10, 100, 1000] {
// 1 ms of data.
let size = mega_bits_per_sec * M / 1000 / 8;
group.throughput(Throughput::Bytes(size as u64));
group.bench_function(BenchmarkId::from_parameter(mega_bits_per_sec), |b| {
let (mut rx, tx) = simplex(M);
// 2ms burst
let burst = mega_bits_per_sec * M / 500;
let mut tx = tx.limit_rate(burst, mega_bits_per_sec * M);
let tx_buf = vec![0; size];
let mut rx_buf = vec![0; size];
b.iter(|| {
async {
futures::try_join!(tx.write_all(&tx_buf), rx.read_exact(&mut rx_buf)).unwrap();
}
.block_on();
black_box(&rx_buf);
});
});
}
}
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
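For reference, a quick check of the sizing arithmetic above at 100 Mbit/s (one of the benchmark's parameter values):

```rust
fn main() {
    const M: usize = 1 << 20;
    let mega_bits_per_sec = 100;
    // Bytes written per iteration: 1 ms worth of data at the target rate.
    let size = mega_bits_per_sec * M / 1000 / 8;
    // Burst allowance: 2 ms worth of bits at the target rate.
    let burst = mega_bits_per_sec * M / 500;
    assert_eq!(size, 13_107);
    assert_eq!(burst, 209_715);
}
```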


@@ -0,0 +1,63 @@
use std::{future::Future, pin::Pin, task::Context, time::Duration};
use futures_timer::Delay;
use crate::Instant;
/// Default interval in milliseconds at which the write side is woken up
/// when it hits the throughput limit. This sets the granularity of the rate
/// limiting and an upper bound on the throughput.
const WAKE_INTERVAL: u64 = 1;
#[derive(Debug)]
pub(crate) struct TokenBucket {
capacity: u64,
tokens: u64,
/// Refill rate in tokens per microsecond.
rate: u64,
last_refill: Instant,
timer: Pin<Box<Delay>>,
}
impl TokenBucket {
/// Create a new `TokenBucket`.
///
/// # Arguments
///
/// * `capacity` - Maximum number of tokens the bucket can hold.
/// * `rate` - Refill rate in tokens per microsecond.
pub(crate) fn new(capacity: u64, rate: u64) -> Self {
Self {
capacity,
tokens: capacity,
rate,
last_refill: Instant::now(),
timer: Box::pin(Delay::new(Duration::from_millis(WAKE_INTERVAL))),
}
}
pub(crate) fn available(&self) -> u64 {
self.tokens
}
pub(crate) fn consume(&mut self, amount: u64) {
self.tokens = self.tokens.saturating_sub(amount);
}
pub(crate) fn poll_refill(&mut self, cx: &mut Context<'_>) {
self.timer.reset(Duration::from_millis(WAKE_INTERVAL));
assert!(self.timer.as_mut().poll(cx).is_pending());
}
pub(crate) fn refill(&mut self) {
let now = Instant::now();
let elapsed = now.duration_since(self.last_refill).as_micros() as u64;
if elapsed == 0 {
return;
}
let tokens = elapsed.saturating_mul(self.rate);
self.tokens = self.tokens.saturating_add(tokens).min(self.capacity);
self.last_refill = now;
}
}
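To make the refill math above concrete, a small illustration (all values hypothetical):

```rust
fn main() {
    // Hypothetical values; the arithmetic mirrors `TokenBucket::refill`.
    let capacity: u64 = 1_000_000;
    let rate: u64 = 125; // tokens per microsecond
    let tokens: u64 = 500_000; // tokens left before the refill
    let elapsed_us: u64 = 2_000; // 2 ms since `last_refill`

    let refilled = elapsed_us.saturating_mul(rate); // 250_000 new tokens
    let tokens = tokens.saturating_add(refilled).min(capacity);
    assert_eq!(tokens, 750_000);
}
```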

futures-limit/src/delay.rs

@@ -0,0 +1,307 @@
use std::{
collections::VecDeque,
io::Result,
pin::Pin,
task::{Context, Poll, Waker, ready},
time::Duration,
};
use bytes::{Buf, BytesMut};
use futures::{AsyncRead, AsyncWrite, lock::BiLock};
use futures_timer::Delay as DelayTimer;
use pin_project_lite::pin_project;
use crate::Instant;
const BUF_SIZE: usize = 16 * 1024; // 16 KiB
/// Delay wrapper for `AsyncRead`.
///
/// This wrapper will delay incoming data by the provided number of
/// milliseconds. A corresponding future is also returned; it should be
/// spawned onto a dedicated thread to ensure that the delay is accurate.
///
/// # Warning
///
/// Incoming data is continuously read from the underlying I/O object into an
/// internal buffer. That buffer will grow without bound if data is processed
/// more slowly than it is received.
#[derive(Debug)]
pub struct Delay<Io> {
read: BiLock<Simplex>,
write: BiLock<Io>,
}
impl<Io> Delay<Io> {
/// Create a new delay.
///
/// Returns a future which must be polled continuously. This future should
/// be spawned onto a dedicated thread to ensure that the delay is accurate.
///
/// # Arguments
///
/// * `io` - Underlying I/O object.
/// * `delay` - Delay in milliseconds.
pub fn new(io: Io, delay: usize) -> (Self, DelayFuture<Io>) {
let simplex = Simplex::new(delay);
let (delay_read, delay_write) = BiLock::new(simplex);
let (io_read, io_write) = BiLock::new(io);
(
Self {
read: delay_read,
write: io_write,
},
DelayFuture {
delay: delay as u64,
read: io_read,
buf: vec![0; BUF_SIZE].into_boxed_slice(),
write: delay_write,
},
)
}
}
impl<Io> AsyncRead for Delay<Io>
where
Io: AsyncRead,
{
#[inline]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize>> {
let mut read = ready!(self.read.poll_lock(cx));
read.poll_read(cx, buf)
}
}
impl<Io> AsyncWrite for Delay<Io>
where
Io: AsyncWrite,
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
let mut write = ready!(self.write.poll_lock(cx));
write.as_pin_mut().poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
let mut write = ready!(self.write.poll_lock(cx));
write.as_pin_mut().poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
let mut write = ready!(self.write.poll_lock(cx));
write.as_pin_mut().poll_close(cx)
}
}
pin_project! {
/// Future returned by [`Delay::new`].
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct DelayFuture<Io> {
delay: u64,
read: BiLock<Io>,
buf: Box<[u8]>,
write: BiLock<Simplex>,
}
}
impl<Io> Future for DelayFuture<Io>
where
Io: AsyncRead,
{
type Output = Result<()>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let mut write = ready!(this.write.poll_lock(cx));
let mut read = ready!(this.read.poll_lock(cx));
let mut len = 0;
let mut closed = false;
while let Poll::Ready(res) = read.as_pin_mut().poll_read(cx, this.buf) {
match res {
Ok(n) => {
if n == 0 {
closed = true;
break;
}
len += n;
write.buf.extend_from_slice(&this.buf[..n]);
}
Err(err) => {
write.close_write();
return Poll::Ready(Err(err));
}
}
}
if len > 0 {
write.packets.push_front(Packet {
len,
ready: Instant::now() + Duration::from_millis(*this.delay),
});
write.wake_reader();
}
if closed {
write.close_write();
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
#[derive(Debug, Clone, Copy)]
struct Packet {
len: usize,
/// Time when the packet is ready.
ready: Instant,
}
#[derive(Debug)]
struct Simplex {
buf: BytesMut,
/// Packets in the buffer.
packets: VecDeque<Packet>,
/// Whether the write side has closed.
is_closed: bool,
/// Waker for the read side.
read_waker: Option<Waker>,
/// Timer to wake up the read side when the latency has elapsed.
read_timer: Pin<Box<DelayTimer>>,
}
impl Simplex {
fn new(delay: usize) -> Self {
Self {
buf: BytesMut::with_capacity(16 * 1024),
packets: VecDeque::new(),
is_closed: false,
read_waker: None,
read_timer: Box::pin(DelayTimer::new(Duration::from_millis(delay as u64))),
}
}
fn close_write(&mut self) {
self.is_closed = true;
// needs to notify any readers that no more data will come
self.wake_reader();
}
fn wake_reader(&mut self) {
if let Some(waker) = self.read_waker.take() {
waker.wake();
}
}
fn poll_read(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
if self.buf.has_remaining() {
// Maximum amount of bytes that can be processed this poll.
let max_len = self.buf.remaining().min(buf.len());
// Read packets in reverse order to process the oldest packets first.
//
// Steps:
// 1. Check that the packet is ready to be read (latency has elapsed). If not,
// register a waker for when it is ready.
// 2. Read as many bytes as possible from the packet. Update the packet length
// if it is partially read.
// 3. Remove fully read packets from the queue.
let mut remaining = max_len;
let mut done_packets = 0;
let now = Instant::now();
for Packet {
len: packet_len,
ready,
} in self.packets.iter_mut().rev()
{
let time_left = ready.saturating_duration_since(now);
if time_left.as_millis() > 0 {
self.read_timer.reset(time_left);
// Poll timer to register waker.
assert!(self.read_timer.as_mut().poll(cx).is_pending());
break;
}
let len = (*packet_len).min(remaining);
if len == *packet_len {
done_packets += 1;
} else {
// Partial read, update packet length.
*packet_len -= len;
}
remaining -= len;
if remaining == 0 {
break;
}
}
if remaining == max_len {
// No packets are ready to be read, so we need to wait for the timer to expire.
return Poll::Pending;
}
// Remove packets that have been fully read.
self.packets.truncate(self.packets.len() - done_packets);
let len = max_len - remaining;
buf[..len].copy_from_slice(&self.buf[..len]);
self.buf.advance(len);
Poll::Ready(Ok(len))
} else if self.is_closed {
Poll::Ready(Ok(0))
} else {
self.read_waker = Some(cx.waker().clone());
Poll::Pending
}
}
}
#[cfg(test)]
mod tests {
use std::{pin::pin, time::Duration};
use super::*;
use futures::{AsyncWriteExt, future::poll_fn, poll};
use futures_plex::simplex;
use mock_instant::thread_local::MockClock;
#[pollster::test]
async fn test_delay() {
let data = b"hello world";
const DELAY: usize = 1;
let (read, mut write) = simplex(100);
let (mut delay, mut fut) = Delay::new(read, DELAY);
write.write_all(data).await.unwrap();
write.flush().await.unwrap();
assert!(poll!(&mut fut).is_pending());
let mut buf = vec![0u8; 11];
let res = poll!(poll_fn(|cx| pin!(&mut delay).poll_read(cx, &mut buf)));
// Data should not be available yet.
assert!(res.is_pending());
MockClock::advance(Duration::from_millis(DELAY as u64));
let res = poll!(poll_fn(|cx| pin!(&mut delay).poll_read(cx, &mut buf)));
// Data should be available now.
assert!(matches!(res, Poll::Ready(Ok(11))));
write.close().await.unwrap();
assert!(poll!(&mut fut).is_ready());
}
}
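A minimal sketch of the usage pattern described in the docs above, assuming a `pollster`-style executor for the dedicated thread (the helper name and bounds are illustrative):

```rust
use futures::AsyncReadExt;
use futures_limit::AsyncReadDelayExt;

// Hypothetical helper: drive the `DelayFuture` on its own thread so the
// delay stays accurate, then read from the wrapper as a normal `AsyncRead`.
async fn read_delayed(io: impl futures::AsyncRead + Send + 'static) -> std::io::Result<Vec<u8>> {
    let (mut delayed, fut) = io.delay(50); // add 50 ms of latency
    std::thread::spawn(move || pollster::block_on(fut));
    let mut buf = vec![0u8; 4];
    delayed.read_exact(&mut buf).await?;
    Ok(buf)
}
```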

futures-limit/src/lib.rs

@@ -0,0 +1,54 @@
#![doc = include_str!("../README.md")]
pub(crate) mod bucket;
mod delay;
mod rate;
pub use delay::{Delay, DelayFuture};
pub use rate::Rate;
use futures::{AsyncRead, AsyncWrite};
#[cfg(test)]
pub(crate) use mock_instant::thread_local::Instant;
#[cfg(all(not(test), not(target_arch = "wasm32")))]
pub(crate) use std::time::Instant;
#[cfg(all(not(test), target_arch = "wasm32"))]
pub(crate) use web_time::Instant;
/// Extension trait for `AsyncWrite`.
pub trait AsyncWriteLimitExt: AsyncWrite {
/// Limit the write rate of the underlying writer.
///
/// # Arguments
///
/// * `burst` - Maximum burst size in bits.
/// * `rate` - Maximum write rate in bits per second.
fn limit_rate(self, burst: usize, rate: usize) -> Rate<Self>
where
Self: Sized,
{
Rate::new(self, burst, rate)
}
}
impl<T> AsyncWriteLimitExt for T where T: AsyncWrite {}
/// Extension trait for `AsyncRead`.
pub trait AsyncReadDelayExt: AsyncRead {
/// Delays incoming data by the given number of milliseconds.
///
/// Returns a future which must be polled continuously. See [`Delay`] for
/// more details.
///
/// # Arguments
///
/// * `delay` - Delay in milliseconds.
fn delay(self, delay: usize) -> (Delay<Self>, DelayFuture<Self>)
where
Self: Sized,
{
Delay::new(self, delay)
}
}
impl<T> AsyncReadDelayExt for T where T: AsyncRead {}

futures-limit/src/rate.rs

@@ -0,0 +1,176 @@
use std::{
io::{IoSliceMut, Result},
pin::Pin,
task::{Context, Poll},
};
use futures::{AsyncRead, AsyncWrite};
use pin_project_lite::pin_project;
use crate::bucket::TokenBucket;
const M: u64 = 1_000_000;
pin_project! {
/// Rate limiting wrapper for `AsyncWrite`.
#[derive(Debug)]
pub struct Rate<Io> {
#[pin] io: Io,
bucket: TokenBucket,
}
}
impl<Io> Rate<Io> {
/// Create a new rate limiter.
///
/// # Arguments
///
/// * `io` - Underlying I/O object.
/// * `burst` - Maximum burst size in bits.
/// * `rate` - Maximum write rate in bits per second.
pub fn new(io: Io, burst: usize, rate: usize) -> Self {
// Bucketing is done with microsecond granularity.
// Each token represents one-millionth of a byte.
let tokens = ((burst as u64) * M).div_ceil(8);
let tokens_per_micro_sec = (rate as u64).div_ceil(8);
let bucket = TokenBucket::new(tokens, tokens_per_micro_sec);
Self { io, bucket }
}
}
impl<Io> Rate<Io>
where
Io: AsyncWrite,
{
fn poll_write_internal(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
let this = self.project();
this.bucket.refill();
let available = (this.bucket.available() / M) as usize;
if available == 0 {
this.bucket.poll_refill(cx);
return Poll::Pending;
}
let len = buf.len().min(available);
let res = this.io.poll_write(cx, &buf[..len]);
if let Poll::Ready(Ok(n)) = &res {
this.bucket.consume((*n as u64) * M);
}
res
}
}
impl<Io> AsyncWrite for Rate<Io>
where
Io: AsyncWrite,
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
self.poll_write_internal(cx, buf)
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
self.project().io.poll_flush(cx)
}
#[inline]
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
self.project().io.poll_close(cx)
}
}
impl<Io> AsyncRead for Rate<Io>
where
Io: AsyncRead,
{
#[inline]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize>> {
self.project().io.poll_read(cx, buf)
}
#[inline]
fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [IoSliceMut<'_>],
) -> Poll<Result<usize>> {
self.project().io.poll_read_vectored(cx, bufs)
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::*;
use futures::{AsyncWriteExt, io::sink, poll};
use mock_instant::thread_local::MockClock;
// Tests that the burst size is respected.
#[pollster::test]
async fn test_rate_burst() {
let data = b"hello world";
let mut io = Rate::new(sink(), (data.len() - 1) * 8, 0);
let n = io.write(data).await.unwrap();
assert_eq!(n, data.len() - 1);
}
// Tests that the burst will allow all data to be written when it is less than
// the burst size.
#[pollster::test]
async fn test_rate_burst_all() {
let data = b"hello world";
let mut io = Rate::new(sink(), data.len() * 8, 0);
let n = io.write(data).await.unwrap();
assert_eq!(n, data.len());
}
#[pollster::test]
async fn test_rate_limit() {
let data = b"hello world";
let mut io = Rate::new(sink(), data.len() * 8, 8);
let n = io.write(data).await.unwrap();
assert_eq!(n, data.len());
let mut write = io.write(data);
assert!(poll!(&mut write).is_pending());
MockClock::advance(Duration::from_secs(1));
let Poll::Ready(Ok(n)) = poll!(write) else {
panic!("poll should be ready");
};
// 1 byte per second.
assert_eq!(n, 1);
let mut write = io.write(data);
assert!(poll!(&mut write).is_pending());
}
}
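The unit bookkeeping in `Rate::new` can be checked by hand; this mirrors the values in `test_rate_limit` above:

```rust
fn main() {
    const M: u64 = 1_000_000; // tokens per byte

    // Values from `test_rate_limit`: an 88-bit (11-byte) burst at 8 bits/s.
    let burst_bits: u64 = 11 * 8;
    let rate_bits_per_sec: u64 = 8;

    // Same conversions as `Rate::new`.
    let tokens = (burst_bits * M).div_ceil(8); // 11 * M tokens = 11 bytes
    let tokens_per_micro = rate_bits_per_sec.div_ceil(8); // 1 token per microsecond

    assert_eq!(tokens / M, 11);
    // 1 token/us accrues M tokens per second: exactly one byte per second.
    assert_eq!(tokens_per_micro * 1_000_000 / M, 1);
}
```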


@@ -0,0 +1,334 @@
use std::{
collections::VecDeque,
future::Future,
io::Result,
pin::Pin,
task::{Context, Poll, Waker},
time::Duration,
};
use bytes::{Buf, BytesMut};
use futures::{
io::{ReadHalf, WriteHalf},
AsyncRead, AsyncReadExt, AsyncWrite,
};
use futures_timer::Delay;
use pin_project_lite::pin_project;
use crate::Instant;
/// Returns a simplex connection pair.
///
/// # Arguments
///
/// * `params` - Parameters for the connection.
pub fn simplex(params: Params) -> (ReadHalf<Simplex>, WriteHalf<Simplex>) {
Simplex::new(params).split()
}
#[derive(Debug, Clone, Copy)]
struct Packet {
len: usize,
/// Time when the packet is ready.
ready: Instant,
}
/// Unidirectional pipe with configurable bandwidth, latency and buffer size.
///
/// Implementation is based on the simplex in `tokio`.
#[derive(Debug)]
pub struct Simplex {
params: Params,
buf: BytesMut,
/// Packets in the buffer.
packets: VecDeque<Packet>,
/// Whether the write side has closed.
is_closed: bool,
/// Waker for the read side.
read_waker: Option<Waker>,
/// Read bucket.
read_bucket: TokenBucket,
/// Timer to wake up the read side when the latency has elapsed.
read_timer: Pin<Box<Delay>>,
/// Waker for the write side.
write_waker: Option<Waker>,
/// Write bucket.
write_bucket: TokenBucket,
}
impl Simplex {
/// Create a new `Simplex`.
///
/// # Panics
///
/// Panics if `tx_rate` or `rx_rate` is less than 8 bits per second.
///
/// # Arguments
///
/// * `params` - Parameters for the connection.
pub fn new(params: Params) -> Self {
assert!(
params.tx_rate >= 8,
"tx_rate must be at least 8 bits per second"
);
assert!(
params.rx_rate >= 8,
"rx_rate must be at least 8 bits per second"
);
let tx_bytes_per_sec = params.tx_rate >> 3;
let rx_bytes_per_sec = params.rx_rate >> 3;
let write_bucket = TokenBucket::new(DEFAULT_BUCKET_CAPACITY, tx_bytes_per_sec >> 20);
let read_bucket = TokenBucket::new(DEFAULT_BUCKET_CAPACITY, rx_bytes_per_sec >> 20);
Self {
params,
buf: BytesMut::new(),
packets: VecDeque::new(),
is_closed: false,
read_waker: None,
read_bucket,
read_timer: Box::pin(Delay::new(Duration::from_millis(0))),
write_waker: None,
write_bucket,
}
}
fn close_write(&mut self) {
self.is_closed = true;
// needs to notify any readers that no more data will come
if let Some(waker) = self.read_waker.take() {
waker.wake();
}
}
fn poll_write_internal(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
if self.is_closed {
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
}
let len = self.params.buf_size - self.buf.len();
if len == 0 {
// Buffer is full, so we need to wait for some data to be read.
self.write_waker = Some(cx.waker().clone());
return Poll::Pending;
}
self.write_bucket.refill();
let len = len.min(self.write_bucket.available());
if len == 0 {
// No tokens available, so we need to wait for the bucket to refill.
assert!(self.write_bucket.poll_refill(cx).is_pending());
return Poll::Pending;
}
let len = len.min(buf.len());
self.buf.extend_from_slice(&buf[..len]);
self.write_bucket.consume(len);
self.packets.push_front(Packet {
len,
ready: Instant::now() + Duration::from_millis(self.params.latency as u64),
});
if let Some(waker) = self.read_waker.take() {
waker.wake();
}
Poll::Ready(Ok(len))
}
fn poll_read_internal(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
if self.buf.has_remaining() {
self.read_bucket.refill();
if self.read_bucket.is_empty() {
// No tokens available, so we need to wait for the bucket to refill.
assert!(self.read_bucket.poll_refill(cx).is_pending());
return Poll::Pending;
}
// Maximum amount of bytes that can be processed this poll.
let max_len = self
.buf
.remaining()
.min(buf.len())
.min(self.read_bucket.available());
// Read packets in reverse order to process the oldest packets first.
//
// Steps:
// 1. Check that the packet is ready to be read (latency has elapsed). If not,
// register a waker for when it is ready.
// 2. Read as many bytes as possible from the packet. Update the packet length
// if it is partially read.
// 3. Remove fully read packets from the queue.
let mut remaining = max_len;
let mut complete = 0;
let now = Instant::now();
for Packet {
len: packet_len,
ready,
} in self.packets.iter_mut().rev()
{
let time_left = ready.saturating_duration_since(now);
if time_left.as_millis() > 0 {
self.read_timer.reset(time_left);
// Poll timer to register waker.
assert!(self.read_timer.as_mut().poll(cx).is_pending());
break;
}
let len = (*packet_len).min(remaining);
if len == *packet_len {
complete += 1;
} else {
// Partial read, update packet length.
*packet_len -= len;
}
remaining -= len;
if remaining == 0 {
break;
}
}
if remaining == max_len {
// No packets are ready to be read, so we need to wait for the timer to expire.
return Poll::Pending;
}
// Remove packets that have been fully read.
self.packets.truncate(self.packets.len() - complete);
let len = max_len - remaining;
buf[..len].copy_from_slice(&self.buf[..len]);
self.buf.advance(len);
self.read_bucket.consume(len);
if len > 0 {
// The passed `buf` might have been empty, don't wake up if
// no bytes have been moved.
if let Some(waker) = self.write_waker.take() {
waker.wake();
}
}
Poll::Ready(Ok(len))
} else if self.is_closed {
Poll::Ready(Ok(0))
} else {
self.read_waker = Some(cx.waker().clone());
Poll::Pending
}
}
}
impl AsyncWrite for Simplex {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
self.poll_write_internal(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
self.close_write();
Poll::Ready(Ok(()))
}
}
impl AsyncRead for Simplex {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize>> {
self.poll_read_internal(cx, buf)
}
}
#[cfg(test)]
mod tests {
use super::*;
use futures::{poll, AsyncReadExt, AsyncWriteExt};
use mock_instant::thread_local::MockClock;
use pollster::FutureExt;
#[test]
fn test_simplex() {
async {
let mut io = Simplex::new(Params {
buf_size: 1024,
tx_rate: 1024 * 8,
rx_rate: 1024 * 8,
latency: 0,
});
let data = b"hello world";
io.write_all(data).await.unwrap();
let mut buf = [0; 1024];
let len = io.read(&mut buf).await.unwrap();
assert_eq!(len, data.len());
assert_eq!(&buf[..len], data);
}
.block_on()
}
#[test]
fn test_simplex_burst_write() {
async {
let mut io = Simplex::new(Params {
buf_size: DEFAULT_BUCKET_CAPACITY,
tx_rate: 8,
rx_rate: 1024 * 8,
latency: 0,
});
let data = vec![0; DEFAULT_BUCKET_CAPACITY];
// Burst write should accept the full buffer.
assert_eq!(
poll!(io.write(&data)).map(|r| r.unwrap()),
Poll::Ready(data.len())
);
assert_eq!(io.write_bucket.tokens, 0);
}
.block_on()
}
#[test]
fn test_simplex_latency() {
async {
let mut io = Simplex::new(Params {
buf_size: DEFAULT_BUCKET_CAPACITY,
tx_rate: 8,
rx_rate: 1024 * 8,
latency: 2,
});
let mut data = vec![0; DEFAULT_BUCKET_CAPACITY];
io.write_all(&data).await.unwrap();
// No time has elapsed, so no data should be available.
assert_eq!(poll!(io.read(&mut data)).map(|r| r.unwrap()), Poll::Pending);
// Latency still hasn't elapsed.
MockClock::advance(Duration::from_millis(1));
assert_eq!(poll!(io.read(&mut data)).map(|r| r.unwrap()), Poll::Pending);
// Latency has elapsed, so data should be available.
MockClock::advance(Duration::from_millis(1));
assert_eq!(
poll!(io.read(&mut data)).map(|r| r.unwrap()),
Poll::Ready(DEFAULT_BUCKET_CAPACITY)
);
}
.block_on()
}
}

futures-plex/Cargo.toml

@@ -0,0 +1,10 @@
[package]
name = "futures-plex"
version = "0.1.0"
edition = "2024"
description = "Port of tokio's `SimplexStream` and `DuplexStream` for the `futures` ecosystem."
[dependencies]
bytes = { version = "1" }
futures-io = { version = "0.3" }
futures-util = { version = "0.3", default-features = false, features = ["io"] }

futures-plex/README.md

@@ -0,0 +1,5 @@
# futures-plex
Port of tokio's `SimplexStream` and `DuplexStream` for the `futures` ecosystem.
This crate provides in-memory implementations for `AsyncRead` and `AsyncWrite`.
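A minimal round-trip sketch, mirroring the `duplex` doc example in `src/lib.rs`:

```rust
use futures_util::{AsyncReadExt, AsyncWriteExt};

// Two connected in-memory endpoints; writes on one side are read on the other.
async fn ping() -> std::io::Result<()> {
    let (mut client, mut server) = futures_plex::duplex(64);
    client.write_all(b"ping").await?;

    let mut buf = [0u8; 4];
    server.read_exact(&mut buf).await?;
    assert_eq!(&buf, b"ping");
    Ok(())
}
```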

futures-plex/src/lib.rs

@@ -0,0 +1,344 @@
#![doc = include_str!("../README.md")]
use std::{
pin::Pin,
task::{self, Poll, Waker},
};
use bytes::{Buf, BytesMut};
use futures_io::{AsyncRead, AsyncWrite};
use futures_util::{
AsyncReadExt,
io::{ReadHalf, WriteHalf},
};
/// A bidirectional pipe to read and write bytes in memory.
///
/// A pair of `DuplexStream`s are created together, and they act as a "channel"
/// that can be used as in-memory IO types. Writing to one of the pairs will
/// allow that data to be read from the other, and vice versa.
///
/// # Closing a `DuplexStream`
///
/// If one end of the `DuplexStream` channel is dropped, any pending reads on
/// the other side will continue to read data until the buffer is drained, then
/// they will signal EOF by returning 0 bytes. Any writes to the other side,
/// including pending ones (that are waiting for free space in the buffer) will
/// return `Err(BrokenPipe)` immediately.
///
/// # Example
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use futures_util::{AsyncReadExt, AsyncWriteExt};
/// let (mut client, mut server) = futures_plex::duplex(64);
///
/// client.write_all(b"ping").await?;
///
/// let mut buf = [0u8; 4];
/// server.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"ping");
///
/// server.write_all(b"pong").await?;
///
/// client.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"pong");
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct DuplexStream {
read: ReadHalf<SimplexStream>,
write: WriteHalf<SimplexStream>,
}
/// A unidirectional pipe to read and write bytes in memory.
///
/// It can be constructed by [`simplex`] function which will create a pair of
/// reader and writer or by calling [`SimplexStream::new_unsplit`] that will
/// create a handle for both reading and writing.
///
/// # Example
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use futures_util::{AsyncReadExt, AsyncWriteExt};
/// let (mut receiver, mut sender) = futures_plex::simplex(64);
///
/// sender.write_all(b"ping").await?;
///
/// let mut buf = [0u8; 4];
/// receiver.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"ping");
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct SimplexStream {
/// The buffer storing the bytes written, also read from.
///
/// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
/// functionality already. Additionally, it can try to copy data within the
/// same buffer if the read index has advanced far enough.
buffer: BytesMut,
/// Determines if the write side has been closed.
is_closed: bool,
/// The maximum amount of bytes that can be written before returning
/// `Poll::Pending`.
max_buf_size: usize,
/// If the `read` side has been polled and is pending, this is the waker
/// for that parked task.
read_waker: Option<Waker>,
/// If the `write` side has filled the `max_buf_size` and returned
/// `Poll::Pending`, this is the waker for that parked task.
write_waker: Option<Waker>,
}
// ===== impl DuplexStream =====
/// Create a new pair of `DuplexStream`s that act like a pair of connected
/// sockets.
///
/// The `max_buf_size` argument is the maximum amount of bytes that can be
/// written to a side before the write returns `Poll::Pending`.
pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
let (read_0, write_0) = SimplexStream::new_unsplit(max_buf_size).split();
let (read_1, write_1) = SimplexStream::new_unsplit(max_buf_size).split();
(
DuplexStream {
read: read_0,
write: write_1,
},
DuplexStream {
read: read_1,
write: write_0,
},
)
}
impl AsyncRead for DuplexStream {
// Previous rustc required this `self` to be `mut`, even though newer
// versions recognize it isn't needed to call `lock()`. So for
// compatibility, we include the `mut` and `allow` the lint.
//
// See https://github.com/rust-lang/rust/issues/73592
#[allow(unused_mut)]
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.read).poll_read(cx, buf)
}
}
impl AsyncWrite for DuplexStream {
#[allow(unused_mut)]
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.write).poll_write(cx, buf)
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
Pin::new(&mut self.write).poll_write_vectored(cx, bufs)
}
#[allow(unused_mut)]
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.write).poll_flush(cx)
}
#[allow(unused_mut)]
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.write).poll_close(cx)
}
}
// ===== impl SimplexStream =====
/// Creates a unidirectional buffer that acts like an in-memory pipe.
///
/// The `max_buf_size` argument is the maximum number of bytes that can be
/// written to the buffer before it returns `Poll::Pending`.
///
/// # Reunite reader and writer
///
/// The reader and writer half can be unified into a single structure
/// of `SimplexStream` that supports both reading and writing or
/// the `SimplexStream` can be already created as unified structure
/// using [`SimplexStream::new_unsplit()`].
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use futures_util::{AsyncReadExt, AsyncWriteExt};
/// let (reader, writer) = futures_plex::simplex(64);
/// let mut simplex_stream = reader.reunite(writer).unwrap();
/// simplex_stream.write_all(b"hello").await?;
///
/// let mut buf = [0u8; 5];
/// simplex_stream.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"hello");
/// # Ok(())
/// # }
/// ```
pub fn simplex(max_buf_size: usize) -> (ReadHalf<SimplexStream>, WriteHalf<SimplexStream>) {
SimplexStream::new_unsplit(max_buf_size).split()
}
impl SimplexStream {
/// Creates a unidirectional buffer that acts like an in-memory pipe. To
/// create a split version with a separate reader and writer, use the
/// [`simplex`] function.
///
/// The `max_buf_size` argument is the maximum number of bytes that can be
/// written to the buffer before it returns `Poll::Pending`.
pub fn new_unsplit(max_buf_size: usize) -> SimplexStream {
SimplexStream {
buffer: BytesMut::new(),
is_closed: false,
max_buf_size,
read_waker: None,
write_waker: None,
}
}
fn close_write(&mut self) {
self.is_closed = true;
// needs to notify any readers that no more data will come
if let Some(waker) = self.read_waker.take() {
waker.wake();
}
}
fn poll_read_internal(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
if self.buffer.has_remaining() {
let len = self.buffer.remaining().min(buf.len());
buf[..len].copy_from_slice(&self.buffer[..len]);
self.buffer.advance(len);
if len > 0 {
// The passed `buf` might have been empty, don't wake up if
// no bytes have been moved.
if let Some(waker) = self.write_waker.take() {
waker.wake();
}
}
Poll::Ready(Ok(len))
} else if self.is_closed {
Poll::Ready(Ok(0))
} else {
self.read_waker = Some(cx.waker().clone());
Poll::Pending
}
}
fn poll_write_internal(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
if self.is_closed {
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
}
let avail = self.max_buf_size - self.buffer.len();
if avail == 0 {
self.write_waker = Some(cx.waker().clone());
return Poll::Pending;
}
let len = buf.len().min(avail);
self.buffer.extend_from_slice(&buf[..len]);
if let Some(waker) = self.read_waker.take() {
waker.wake();
}
Poll::Ready(Ok(len))
}
fn poll_write_vectored_internal(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
if self.is_closed {
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
}
let avail = self.max_buf_size - self.buffer.len();
if avail == 0 {
self.write_waker = Some(cx.waker().clone());
return Poll::Pending;
}
let mut rem = avail;
for buf in bufs {
if rem == 0 {
break;
}
let len = buf.len().min(rem);
self.buffer.extend_from_slice(&buf[..len]);
rem -= len;
}
if let Some(waker) = self.read_waker.take() {
waker.wake();
}
Poll::Ready(Ok(avail - rem))
}
}
impl AsyncRead for SimplexStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
self.poll_read_internal(cx, buf)
}
}
impl AsyncWrite for SimplexStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.poll_write_internal(cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<std::io::Result<usize>> {
self.poll_write_vectored_internal(cx, bufs)
}
fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_close(
mut self: Pin<&mut Self>,
_: &mut task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.close_write();
Poll::Ready(Ok(()))
}
}
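A small sketch exercising the vectored write path implemented above (the function name and values are illustrative):

```rust
use std::io::IoSlice;
use futures_util::{AsyncReadExt, AsyncWriteExt};

// Both slices are gathered into the pipe in one (possibly partial) write.
async fn vectored() -> std::io::Result<()> {
    let (mut rx, mut tx) = futures_plex::simplex(64);
    let bufs = [IoSlice::new(b"hello "), IoSlice::new(b"world")];
    let n = tx.write_vectored(&bufs).await?;
    assert_eq!(n, 11); // both slices fit within `max_buf_size`

    let mut out = vec![0u8; n];
    rx.read_exact(&mut out).await?;
    assert_eq!(&out[..], &b"hello world"[..]);
    Ok(())
}
```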


@@ -1,2 +1,2 @@
imports_granularity = "Crate"
wrap_comments = true