mirror of
https://github.com/tlsnotary/tlsn-utils.git
synced 2026-01-06 19:33:55 -05:00
feat: futures-limit and futures-plex (#51)
* feat: futures-limit and futures-plex * use web time for wasm
This commit is contained in:
@@ -7,10 +7,13 @@ members = [
|
||||
"utils-aio",
|
||||
"utils/fuzz",
|
||||
"websocket-relay",
|
||||
"futures-limit",
|
||||
"futures-plex",
|
||||
"web-spawn",
|
||||
]
|
||||
|
||||
[workspace.dependencies]
|
||||
futures-plex = { path = "futures-plex" }
|
||||
serio = { path = "serio" }
|
||||
spansy = { path = "spansy" }
|
||||
tlsn-utils = { path = "utils" }
|
||||
@@ -23,6 +26,7 @@ async-tungstenite = "0.16"
|
||||
bincode = "1.3"
|
||||
bytes = "1"
|
||||
cfg-if = "1"
|
||||
criterion = "0.5"
|
||||
futures = "0.3"
|
||||
futures-channel = "0.3"
|
||||
futures-core = "0.3"
|
||||
@@ -30,6 +34,7 @@ futures-io = "0.3"
|
||||
futures-sink = "0.3"
|
||||
futures-util = "0.3"
|
||||
pin-project-lite = "0.2"
|
||||
pollster = "0.4"
|
||||
prost = "0.9"
|
||||
prost-build = "0.9"
|
||||
rand = "0.8"
|
||||
|
||||
24
futures-limit/Cargo.toml
Normal file
24
futures-limit/Cargo.toml
Normal file
@@ -0,0 +1,24 @@
|
||||
[package]
|
||||
name = "futures-limit"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
bytes = { workspace = true }
|
||||
futures = { workspace = true, features = ["bilock", "unstable"] }
|
||||
futures-timer = { version = "3" }
|
||||
pin-project-lite = { workspace = true }
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
futures-timer = { version = "3", features = ["wasm-bindgen"] }
|
||||
web-time = { version = "1.1" }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { workspace = true }
|
||||
pollster = { workspace = true, features = ["macro"] }
|
||||
mock_instant = "0.5"
|
||||
futures-plex = { workspace = true }
|
||||
|
||||
[[bench]]
|
||||
name = "bench"
|
||||
harness = false
|
||||
3
futures-limit/README.md
Normal file
3
futures-limit/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# futures-limit
|
||||
|
||||
This crate provides a rate limiting wrapper for `AsyncWrite` and a delay wrapper for `AsyncRead`.
|
||||
56
futures-limit/benches/bench.rs
Normal file
56
futures-limit/benches/bench.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use criterion::{BenchmarkId, Criterion, Throughput, black_box, criterion_group, criterion_main};
|
||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
||||
use futures_limit::AsyncWriteLimitExt;
|
||||
use futures_plex::simplex;
|
||||
use pollster::FutureExt as _;
|
||||
|
||||
const M: usize = 1 << 20;
|
||||
|
||||
pub fn criterion_benchmark(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("rate");
|
||||
|
||||
group.throughput(Throughput::Bytes(M as u64));
|
||||
group.bench_function("max", |b| {
|
||||
let (mut rx, tx) = simplex(M);
|
||||
let mut tx = tx.limit_rate(8 * M, usize::MAX);
|
||||
let tx_buf = vec![0; M];
|
||||
let mut rx_buf = vec![0; M];
|
||||
|
||||
b.iter(|| {
|
||||
async {
|
||||
futures::try_join!(tx.write_all(&tx_buf), rx.read_exact(&mut rx_buf)).unwrap();
|
||||
}
|
||||
.block_on();
|
||||
black_box(&rx_buf);
|
||||
});
|
||||
});
|
||||
|
||||
for mega_bits_per_sec in [10, 100, 1000] {
|
||||
// 1 ms of data.
|
||||
let size = mega_bits_per_sec * M / 1000 / 8;
|
||||
|
||||
group.throughput(Throughput::Bytes(size as u64));
|
||||
group.bench_function(BenchmarkId::from_parameter(mega_bits_per_sec), |b| {
|
||||
let (mut rx, tx) = simplex(M);
|
||||
|
||||
// 2ms burst
|
||||
let burst = mega_bits_per_sec * M / 500;
|
||||
|
||||
let mut tx = tx.limit_rate(burst, mega_bits_per_sec * M);
|
||||
|
||||
let tx_buf = vec![0; size];
|
||||
let mut rx_buf = vec![0; size];
|
||||
|
||||
b.iter(|| {
|
||||
async {
|
||||
futures::try_join!(tx.write_all(&tx_buf), rx.read_exact(&mut rx_buf)).unwrap();
|
||||
}
|
||||
.block_on();
|
||||
black_box(&rx_buf);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
||||
criterion_main!(benches);
|
||||
63
futures-limit/src/bucket.rs
Normal file
63
futures-limit/src/bucket.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
use std::{future::Future, pin::Pin, task::Context, time::Duration};
|
||||
|
||||
use futures_timer::Delay;
|
||||
|
||||
use crate::Instant;
|
||||
|
||||
/// Default interval in millis in which the write side is woken up when
|
||||
/// reaching throughput limits. This sets the granularity of the rate limiting
|
||||
/// and an upper bound on the throughput.
|
||||
const WAKE_INTERVAL: u64 = 1;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct TokenBucket {
|
||||
capacity: u64,
|
||||
tokens: u64,
|
||||
/// Refill rate in tokens per micro second.
|
||||
rate: u64,
|
||||
last_refill: Instant,
|
||||
timer: Pin<Box<Delay>>,
|
||||
}
|
||||
|
||||
impl TokenBucket {
|
||||
/// Create a new `TokenBucket`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `capacity` - Maximum number of tokens the bucket can hold.
|
||||
/// * `rate` - Refill rate in tokens per microsecond.
|
||||
pub(crate) fn new(capacity: u64, rate: u64) -> Self {
|
||||
Self {
|
||||
capacity,
|
||||
tokens: capacity,
|
||||
rate,
|
||||
last_refill: Instant::now(),
|
||||
timer: Box::pin(Delay::new(Duration::from_millis(WAKE_INTERVAL))),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn available(&self) -> u64 {
|
||||
self.tokens
|
||||
}
|
||||
|
||||
pub(crate) fn consume(&mut self, amount: u64) {
|
||||
self.tokens = self.tokens.saturating_sub(amount);
|
||||
}
|
||||
|
||||
pub(crate) fn poll_refill(&mut self, cx: &mut Context<'_>) {
|
||||
self.timer.reset(Duration::from_millis(WAKE_INTERVAL));
|
||||
assert!(self.timer.as_mut().poll(cx).is_pending());
|
||||
}
|
||||
|
||||
pub(crate) fn refill(&mut self) {
|
||||
let now = Instant::now();
|
||||
let elapsed = now.duration_since(self.last_refill).as_micros() as u64;
|
||||
if elapsed == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let tokens = elapsed.saturating_mul(self.rate);
|
||||
self.tokens = self.tokens.saturating_add(tokens).min(self.capacity);
|
||||
self.last_refill = now;
|
||||
}
|
||||
}
|
||||
307
futures-limit/src/delay.rs
Normal file
307
futures-limit/src/delay.rs
Normal file
@@ -0,0 +1,307 @@
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
io::Result,
|
||||
pin::Pin,
|
||||
task::{Context, Poll, Waker, ready},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use bytes::{Buf, BytesMut};
|
||||
use futures::{AsyncRead, AsyncWrite, lock::BiLock};
|
||||
use futures_timer::Delay as DelayTimer;
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
use crate::Instant;
|
||||
|
||||
const BUF_SIZE: usize = 16 * 1024; // 16 KiB
|
||||
|
||||
/// Delay wrapper for `AsyncRead`.
|
||||
///
|
||||
/// This wrapper will delay incoming data by the provided amount of
|
||||
/// milliseconds. A corresponding future is also returned. This future should be
|
||||
/// spawned onto a dedicated thread to ensure that the delay is accurate.
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// Incoming data is continuously read from the underlying I/O object. This
|
||||
/// buffer will continue to grow unbounded if the data is processed slower than
|
||||
/// it is received.
|
||||
#[derive(Debug)]
|
||||
pub struct Delay<Io> {
|
||||
read: BiLock<Simplex>,
|
||||
write: BiLock<Io>,
|
||||
}
|
||||
|
||||
impl<Io> Delay<Io> {
|
||||
/// Create a new delay.
|
||||
///
|
||||
/// Returns a future which must be polled continuously. This future should
|
||||
/// be spawned onto a dedicated thread to ensure that the delay is accurate.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `io` - Underlying I/O object.
|
||||
/// * `delay` - Delay in milliseconds.
|
||||
pub fn new(io: Io, delay: usize) -> (Self, DelayFuture<Io>) {
|
||||
let simplex = Simplex::new(delay);
|
||||
|
||||
let (delay_read, delay_write) = BiLock::new(simplex);
|
||||
let (io_read, io_write) = BiLock::new(io);
|
||||
|
||||
(
|
||||
Self {
|
||||
read: delay_read,
|
||||
write: io_write,
|
||||
},
|
||||
DelayFuture {
|
||||
delay: delay as u64,
|
||||
read: io_read,
|
||||
buf: vec![0; BUF_SIZE].into_boxed_slice(),
|
||||
write: delay_write,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Io> AsyncRead for Delay<Io>
|
||||
where
|
||||
Io: AsyncRead,
|
||||
{
|
||||
#[inline]
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
let mut read = ready!(self.read.poll_lock(cx));
|
||||
read.poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Io> AsyncWrite for Delay<Io>
|
||||
where
|
||||
Io: AsyncWrite,
|
||||
{
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
|
||||
let mut write = ready!(self.write.poll_lock(cx));
|
||||
write.as_pin_mut().poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
let mut write = ready!(self.write.poll_lock(cx));
|
||||
write.as_pin_mut().poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
let mut write = ready!(self.write.poll_lock(cx));
|
||||
write.as_pin_mut().poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// Future returned by [`Delay::new`].
|
||||
#[must_use = "futures do nothing unless you `.await` or poll them"]
|
||||
pub struct DelayFuture<Io> {
|
||||
delay: u64,
|
||||
read: BiLock<Io>,
|
||||
buf: Box<[u8]>,
|
||||
write: BiLock<Simplex>,
|
||||
}
|
||||
}
|
||||
|
||||
impl<Io> Future for DelayFuture<Io>
|
||||
where
|
||||
Io: AsyncRead,
|
||||
{
|
||||
type Output = Result<()>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.project();
|
||||
|
||||
let mut write = ready!(this.write.poll_lock(cx));
|
||||
let mut read = ready!(this.read.poll_lock(cx));
|
||||
|
||||
let mut len = 0;
|
||||
let mut closed = false;
|
||||
while let Poll::Ready(res) = read.as_pin_mut().poll_read(cx, this.buf) {
|
||||
match res {
|
||||
Ok(n) => {
|
||||
if n == 0 {
|
||||
closed = true;
|
||||
break;
|
||||
}
|
||||
len += n;
|
||||
write.buf.extend_from_slice(&this.buf[..n]);
|
||||
}
|
||||
Err(err) => {
|
||||
write.close_write();
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len > 0 {
|
||||
write.packets.push_front(Packet {
|
||||
len,
|
||||
ready: Instant::now() + Duration::from_millis(*this.delay),
|
||||
});
|
||||
write.wake_reader();
|
||||
}
|
||||
|
||||
if closed {
|
||||
write.close_write();
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct Packet {
|
||||
len: usize,
|
||||
/// Time when the packet is ready.
|
||||
ready: Instant,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Simplex {
|
||||
buf: BytesMut,
|
||||
/// Packets in the buffer.
|
||||
packets: VecDeque<Packet>,
|
||||
/// Whether the write side has closed.
|
||||
is_closed: bool,
|
||||
/// Waker for the read side.
|
||||
read_waker: Option<Waker>,
|
||||
/// Timer to wake up the read side when the latency has elapsed.
|
||||
read_timer: Pin<Box<DelayTimer>>,
|
||||
}
|
||||
|
||||
impl Simplex {
|
||||
fn new(delay: usize) -> Self {
|
||||
Self {
|
||||
buf: BytesMut::with_capacity(16 * 1024),
|
||||
packets: VecDeque::new(),
|
||||
is_closed: false,
|
||||
read_waker: None,
|
||||
read_timer: Box::pin(DelayTimer::new(Duration::from_millis(delay as u64))),
|
||||
}
|
||||
}
|
||||
|
||||
fn close_write(&mut self) {
|
||||
self.is_closed = true;
|
||||
// needs to notify any readers that no more data will come
|
||||
self.wake_reader();
|
||||
}
|
||||
|
||||
fn wake_reader(&mut self) {
|
||||
if let Some(waker) = self.read_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_read(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
|
||||
if self.buf.has_remaining() {
|
||||
// Maximum amount of bytes that can be processed this poll.
|
||||
let max_len = self.buf.remaining().min(buf.len());
|
||||
|
||||
// Read packets in reverse order to process the oldest packets first.
|
||||
//
|
||||
// Steps:
|
||||
// 1. Check that the packet is ready to be read (latency has elapsed). If not,
|
||||
// register a waker for when it is ready.
|
||||
// 2. Read as many bytes as possible from the packet. Update the packet length
|
||||
// if it is partially read.
|
||||
// 3. Remove fully read packets from the queue.
|
||||
let mut remaining = max_len;
|
||||
let mut done_packets = 0;
|
||||
let now = Instant::now();
|
||||
for Packet {
|
||||
len: packet_len,
|
||||
ready,
|
||||
} in self.packets.iter_mut().rev()
|
||||
{
|
||||
let time_left = ready.saturating_duration_since(now);
|
||||
if time_left.as_millis() > 0 {
|
||||
self.read_timer.reset(time_left);
|
||||
// Poll timer to register waker.
|
||||
assert!(self.read_timer.as_mut().poll(cx).is_pending());
|
||||
break;
|
||||
}
|
||||
|
||||
let len = (*packet_len).min(remaining);
|
||||
if len == *packet_len {
|
||||
done_packets += 1;
|
||||
} else {
|
||||
// Partial read, update packet length.
|
||||
*packet_len -= len;
|
||||
}
|
||||
|
||||
remaining -= len;
|
||||
if remaining == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if remaining == max_len {
|
||||
// No packets are ready to be read, so we need to wait for the timer to expire.
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
// Remove packets that have been fully read.
|
||||
self.packets.truncate(self.packets.len() - done_packets);
|
||||
|
||||
let len = max_len - remaining;
|
||||
buf[..len].copy_from_slice(&self.buf[..len]);
|
||||
self.buf.advance(len);
|
||||
|
||||
Poll::Ready(Ok(len))
|
||||
} else if self.is_closed {
|
||||
Poll::Ready(Ok(0))
|
||||
} else {
|
||||
self.read_waker = Some(cx.waker().clone());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{pin::pin, time::Duration};
|
||||
|
||||
use super::*;
|
||||
use futures::{AsyncWriteExt, future::poll_fn, poll};
|
||||
use futures_plex::simplex;
|
||||
use mock_instant::thread_local::MockClock;
|
||||
|
||||
#[pollster::test]
|
||||
async fn test_delay() {
|
||||
let data = b"hello world";
|
||||
const DELAY: usize = 1;
|
||||
|
||||
let (read, mut write) = simplex(100);
|
||||
let (mut delay, mut fut) = Delay::new(read, DELAY);
|
||||
|
||||
write.write_all(data).await.unwrap();
|
||||
write.flush().await.unwrap();
|
||||
|
||||
assert!(poll!(&mut fut).is_pending());
|
||||
|
||||
let mut buf = vec![0u8; 11];
|
||||
let res = poll!(poll_fn(|cx| pin!(&mut delay).poll_read(cx, &mut buf)));
|
||||
|
||||
// Data should not be available yet.
|
||||
assert!(res.is_pending());
|
||||
|
||||
MockClock::advance(Duration::from_millis(DELAY as u64));
|
||||
|
||||
let res = poll!(poll_fn(|cx| pin!(&mut delay).poll_read(cx, &mut buf)));
|
||||
|
||||
// Data should be available now.
|
||||
assert!(matches!(res, Poll::Ready(Ok(11))));
|
||||
|
||||
write.close().await.unwrap();
|
||||
|
||||
assert!(poll!(&mut fut).is_ready());
|
||||
}
|
||||
}
|
||||
54
futures-limit/src/lib.rs
Normal file
54
futures-limit/src/lib.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
#![doc = include_str!("../README.md")]
|
||||
|
||||
pub(crate) mod bucket;
|
||||
mod delay;
|
||||
mod rate;
|
||||
|
||||
pub use delay::{Delay, DelayFuture};
|
||||
pub use rate::Rate;
|
||||
|
||||
use futures::{AsyncRead, AsyncWrite};
|
||||
#[cfg(test)]
|
||||
pub(crate) use mock_instant::thread_local::Instant;
|
||||
#[cfg(all(not(test), not(target_arch = "wasm32")))]
|
||||
pub(crate) use std::time::Instant;
|
||||
#[cfg(all(not(test), target_arch = "wasm32"))]
|
||||
pub(crate) use web_time::Instant;
|
||||
|
||||
/// Extension trait for `AsyncWrite`.
|
||||
pub trait AsyncWriteLimitExt: AsyncWrite {
|
||||
/// Limit the write rate of the underlying writer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `burst` - Maximum burst size in bits.
|
||||
/// * `rate` - Maximum write rate in bits per second.
|
||||
fn limit_rate(self, burst: usize, rate: usize) -> Rate<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Rate::new(self, burst, rate)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AsyncWriteLimitExt for T where T: AsyncWrite {}
|
||||
|
||||
/// Extension trait for `AsyncRead`.
|
||||
pub trait AsyncReadDelayExt: AsyncRead {
|
||||
/// Delays incoming data by the given amount of milliseconds.
|
||||
///
|
||||
/// Returns a future which must be polled continuously. See [`Delay`] for
|
||||
/// more details.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `delay` - Delay in milliseconds.
|
||||
fn delay(self, delay: usize) -> (Delay<Self>, DelayFuture<Self>)
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Delay::new(self, delay)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AsyncReadDelayExt for T where T: AsyncRead {}
|
||||
176
futures-limit/src/rate.rs
Normal file
176
futures-limit/src/rate.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
use std::{
|
||||
io::{IoSliceMut, Result},
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures::{AsyncRead, AsyncWrite};
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
use crate::bucket::TokenBucket;
|
||||
|
||||
const M: u64 = 1_000_000;
|
||||
|
||||
pin_project! {
|
||||
/// Rate limiting wrapper for `AsyncWrite`.
|
||||
#[derive(Debug)]
|
||||
pub struct Rate<Io> {
|
||||
#[pin] io: Io,
|
||||
bucket: TokenBucket,
|
||||
}
|
||||
}
|
||||
|
||||
impl<Io> Rate<Io> {
|
||||
/// Create a new rate limiter.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `io` - Underlying I/O object.
|
||||
/// * `burst` - Maximum burst size in bits.
|
||||
/// * `rate` - Maximum write rate in bits per second.
|
||||
pub fn new(io: Io, burst: usize, rate: usize) -> Self {
|
||||
// Bucketing is done with microsecond granularity.
|
||||
// Each token represents one-millionth of a byte.
|
||||
let tokens = ((burst as u64) * M).div_ceil(8);
|
||||
let tokens_per_micro_sec = (rate as u64).div_ceil(8);
|
||||
|
||||
let bucket = TokenBucket::new(tokens, tokens_per_micro_sec);
|
||||
|
||||
Self { io, bucket }
|
||||
}
|
||||
}
|
||||
|
||||
impl<Io> Rate<Io>
|
||||
where
|
||||
Io: AsyncWrite,
|
||||
{
|
||||
fn poll_write_internal(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
let this = self.project();
|
||||
|
||||
this.bucket.refill();
|
||||
|
||||
let available = (this.bucket.available() / M) as usize;
|
||||
if available == 0 {
|
||||
this.bucket.poll_refill(cx);
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
let len = buf.len().min(available);
|
||||
|
||||
let res = this.io.poll_write(cx, &buf[..len]);
|
||||
|
||||
if let Poll::Ready(Ok(n)) = &res {
|
||||
this.bucket.consume((*n as u64) * M);
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl<Io> AsyncWrite for Rate<Io>
|
||||
where
|
||||
Io: AsyncWrite,
|
||||
{
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
|
||||
self.poll_write_internal(cx, buf)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
self.project().io.poll_flush(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
self.project().io.poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Io> AsyncRead for Rate<Io>
|
||||
where
|
||||
Io: AsyncRead,
|
||||
{
|
||||
#[inline]
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
self.project().io.poll_read(cx, buf)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_read_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &mut [IoSliceMut<'_>],
|
||||
) -> Poll<Result<usize>> {
|
||||
self.project().io.poll_read_vectored(cx, bufs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use super::*;
|
||||
use futures::{AsyncWriteExt, io::sink, poll};
|
||||
use mock_instant::thread_local::MockClock;
|
||||
|
||||
// Tests that the burst size is respected.
|
||||
#[pollster::test]
|
||||
async fn test_rate_burst() {
|
||||
let data = b"hello world";
|
||||
|
||||
let mut io = Rate::new(sink(), (data.len() - 1) * 8, 0);
|
||||
|
||||
let n = io.write(data).await.unwrap();
|
||||
|
||||
assert_eq!(n, data.len() - 1);
|
||||
}
|
||||
|
||||
// Tests that the burst will allow all data to be written when it is less than
|
||||
// the burst size.
|
||||
#[pollster::test]
|
||||
async fn test_rate_burst_all() {
|
||||
let data = b"hello world";
|
||||
|
||||
let mut io = Rate::new(sink(), data.len() * 8, 0);
|
||||
|
||||
let n = io.write(data).await.unwrap();
|
||||
|
||||
assert_eq!(n, data.len());
|
||||
}
|
||||
|
||||
#[pollster::test]
|
||||
async fn test_rate_limit() {
|
||||
let data = b"hello world";
|
||||
|
||||
let mut io = Rate::new(sink(), data.len() * 8, 8);
|
||||
|
||||
let n = io.write(data).await.unwrap();
|
||||
|
||||
assert_eq!(n, data.len());
|
||||
|
||||
let mut write = io.write(data);
|
||||
|
||||
assert!(poll!(&mut write).is_pending());
|
||||
|
||||
MockClock::advance(Duration::from_secs(1));
|
||||
|
||||
let Poll::Ready(Ok(n)) = poll!(write) else {
|
||||
panic!("poll should be ready");
|
||||
};
|
||||
|
||||
// 1 byte per second.
|
||||
assert_eq!(n, 1);
|
||||
|
||||
let mut write = io.write(data);
|
||||
|
||||
assert!(poll!(&mut write).is_pending());
|
||||
}
|
||||
}
|
||||
334
futures-limit/src/simplex.rs
Normal file
334
futures-limit/src/simplex.rs
Normal file
@@ -0,0 +1,334 @@
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
future::Future,
|
||||
io::Result,
|
||||
pin::Pin,
|
||||
task::{Context, Poll, Waker},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use bytes::{Buf, BytesMut};
|
||||
use futures::{
|
||||
io::{ReadHalf, WriteHalf},
|
||||
AsyncRead, AsyncReadExt, AsyncWrite,
|
||||
};
|
||||
use futures_timer::Delay;
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
use crate::Instant;
|
||||
|
||||
/// Returns a simplex connection pair.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `params` - Parameters for the connection.
|
||||
pub fn simplex(params: Params) -> (ReadHalf<Simplex>, WriteHalf<Simplex>) {
|
||||
Simplex::new(params).split()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct Packet {
|
||||
len: usize,
|
||||
/// Time when the packet is ready.
|
||||
ready: Instant,
|
||||
}
|
||||
|
||||
/// Unidirectional pipe with configurable bandwidth, latency and buffer size.
|
||||
///
|
||||
/// Implementation is based on the simplex in `tokio`.
|
||||
#[derive(Debug)]
|
||||
pub struct Simplex {
|
||||
params: Params,
|
||||
buf: BytesMut,
|
||||
/// Packets in the buffer.
|
||||
packets: VecDeque<Packet>,
|
||||
/// Whether the write side has closed.
|
||||
is_closed: bool,
|
||||
/// Waker for the read side.
|
||||
read_waker: Option<Waker>,
|
||||
/// Read bucket.
|
||||
read_bucket: TokenBucket,
|
||||
/// Timer to wake up the read side when the latency has elapsed.
|
||||
read_timer: Pin<Box<Delay>>,
|
||||
/// Waker for the write side.
|
||||
write_waker: Option<Waker>,
|
||||
/// Write bucket.
|
||||
write_bucket: TokenBucket,
|
||||
}
|
||||
|
||||
impl Simplex {
|
||||
/// Create a new `Simplex`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `tx_rate` or `rx_rate` are less than 8 bits per second.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `params` - Parameters for the connection.
|
||||
pub fn new(params: Params) -> Self {
|
||||
assert!(
|
||||
params.tx_rate >= 8,
|
||||
"tx_rate must be at least 8 bits per second"
|
||||
);
|
||||
assert!(
|
||||
params.rx_rate >= 8,
|
||||
"rx_rate must be at least 8 bits per second"
|
||||
);
|
||||
|
||||
let tx_bytes_per_sec = params.tx_rate >> 3;
|
||||
let rx_bytes_per_sec = params.rx_rate >> 3;
|
||||
|
||||
let write_bucket = TokenBucket::new(DEFAULT_BUCKET_CAPACITY, tx_bytes_per_sec >> 20);
|
||||
let read_bucket = TokenBucket::new(DEFAULT_BUCKET_CAPACITY, rx_bytes_per_sec >> 20);
|
||||
|
||||
Self {
|
||||
params,
|
||||
buf: BytesMut::new(),
|
||||
packets: VecDeque::new(),
|
||||
is_closed: false,
|
||||
read_waker: None,
|
||||
read_bucket,
|
||||
read_timer: Box::pin(Delay::new(Duration::from_millis(0))),
|
||||
write_waker: None,
|
||||
write_bucket,
|
||||
}
|
||||
}
|
||||
|
||||
fn close_write(&mut self) {
|
||||
self.is_closed = true;
|
||||
// needs to notify any readers that no more data will come
|
||||
if let Some(waker) = self.read_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write_internal(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
|
||||
if self.is_closed {
|
||||
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
|
||||
}
|
||||
|
||||
let len = self.params.buf_size - self.buf.len();
|
||||
if len == 0 {
|
||||
// Buffer is full, so we need to wait for some data to be read.
|
||||
self.write_waker = Some(cx.waker().clone());
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
self.write_bucket.refill();
|
||||
let len = len.min(self.write_bucket.available());
|
||||
if len == 0 {
|
||||
// No tokens available, so we need to wait for the bucket to refill.
|
||||
assert!(self.write_bucket.poll_refill(cx).is_pending());
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
let len = len.min(buf.len());
|
||||
self.buf.extend_from_slice(&buf[..len]);
|
||||
self.write_bucket.consume(len);
|
||||
self.packets.push_front(Packet {
|
||||
len,
|
||||
ready: Instant::now() + Duration::from_millis(self.params.latency as u64),
|
||||
});
|
||||
|
||||
if let Some(waker) = self.read_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(len))
|
||||
}
|
||||
|
||||
fn poll_read_internal(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
|
||||
if self.buf.has_remaining() {
|
||||
self.read_bucket.refill();
|
||||
if self.read_bucket.is_empty() {
|
||||
// No tokens available, so we need to wait for the bucket to refill.
|
||||
assert!(self.read_bucket.poll_refill(cx).is_pending());
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
// Maximum amount of bytes that can be processed this poll.
|
||||
let max_len = self
|
||||
.buf
|
||||
.remaining()
|
||||
.min(buf.len())
|
||||
.min(self.read_bucket.available());
|
||||
|
||||
// Read packets in reverse order to process the oldest packets first.
|
||||
//
|
||||
// Steps:
|
||||
// 1. Check that the packet is ready to be read (latency has elapsed). If not,
|
||||
// register a waker for when it is ready.
|
||||
// 2. Read as many bytes as possible from the packet. Update the packet length
|
||||
// if it is partially read.
|
||||
// 3. Remove fully read packets from the queue.
|
||||
let mut remaining = max_len;
|
||||
let mut complete = 0;
|
||||
let now = Instant::now();
|
||||
for Packet {
|
||||
len: packet_len,
|
||||
ready,
|
||||
} in self.packets.iter_mut().rev()
|
||||
{
|
||||
let time_left = ready.saturating_duration_since(now);
|
||||
if time_left.as_millis() > 0 {
|
||||
self.read_timer.reset(time_left);
|
||||
// Poll timer to register waker.
|
||||
assert!(self.read_timer.as_mut().poll(cx).is_pending());
|
||||
break;
|
||||
}
|
||||
|
||||
let len = (*packet_len).min(remaining);
|
||||
if len == *packet_len {
|
||||
complete += 1;
|
||||
} else {
|
||||
// Partial read, update packet length.
|
||||
*packet_len -= len;
|
||||
}
|
||||
|
||||
remaining -= len;
|
||||
if remaining == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if remaining == max_len {
|
||||
// No packets are ready to be read, so we need to wait for the timer to expire.
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
// Remove packets that have been fully read.
|
||||
self.packets.truncate(self.packets.len() - complete);
|
||||
|
||||
let len = max_len - remaining;
|
||||
buf[..len].copy_from_slice(&self.buf[..len]);
|
||||
self.buf.advance(len);
|
||||
self.read_bucket.consume(len);
|
||||
if len > 0 {
|
||||
// The passed `buf` might have been empty, don't wake up if
|
||||
// no bytes have been moved.
|
||||
if let Some(waker) = self.write_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(len))
|
||||
} else if self.is_closed {
|
||||
Poll::Ready(Ok(0))
|
||||
} else {
|
||||
self.read_waker = Some(cx.waker().clone());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for Simplex {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
self.poll_write_internal(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
self.close_write();
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for Simplex {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
self.poll_read_internal(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use futures::{poll, AsyncReadExt, AsyncWriteExt};
|
||||
use mock_instant::thread_local::MockClock;
|
||||
use pollster::FutureExt;
|
||||
|
||||
#[test]
|
||||
fn test_simplex() {
|
||||
async {
|
||||
let mut io = Simplex::new(Params {
|
||||
buf_size: 1024,
|
||||
tx_rate: 1024 * 8,
|
||||
rx_rate: 1024 * 8,
|
||||
latency: 0,
|
||||
});
|
||||
|
||||
let data = b"hello world";
|
||||
io.write_all(data).await.unwrap();
|
||||
|
||||
let mut buf = [0; 1024];
|
||||
let len = io.read(&mut buf).await.unwrap();
|
||||
|
||||
assert_eq!(len, data.len());
|
||||
assert_eq!(&buf[..len], data);
|
||||
}
|
||||
.block_on()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simplex_burst_write() {
|
||||
async {
|
||||
let mut io = Simplex::new(Params {
|
||||
buf_size: DEFAULT_BUCKET_CAPACITY,
|
||||
tx_rate: 8,
|
||||
rx_rate: 1024 * 8,
|
||||
latency: 0,
|
||||
});
|
||||
|
||||
let data = vec![0; DEFAULT_BUCKET_CAPACITY];
|
||||
|
||||
// Burst write should accept the full buffer.
|
||||
assert_eq!(
|
||||
poll!(io.write(&data)).map(|r| r.unwrap()),
|
||||
Poll::Ready(data.len())
|
||||
);
|
||||
assert_eq!(io.write_bucket.tokens, 0);
|
||||
}
|
||||
.block_on()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simplex_latency() {
|
||||
async {
|
||||
let mut io = Simplex::new(Params {
|
||||
buf_size: DEFAULT_BUCKET_CAPACITY,
|
||||
tx_rate: 8,
|
||||
rx_rate: 1024 * 8,
|
||||
latency: 2,
|
||||
});
|
||||
|
||||
let mut data = vec![0; DEFAULT_BUCKET_CAPACITY];
|
||||
io.write_all(&data).await.unwrap();
|
||||
|
||||
// No time has elapsed, so no data should be available.
|
||||
assert_eq!(poll!(io.read(&mut data)).map(|r| r.unwrap()), Poll::Pending);
|
||||
|
||||
// Latency still hasn't elapsed.
|
||||
MockClock::advance(Duration::from_millis(1));
|
||||
assert_eq!(poll!(io.read(&mut data)).map(|r| r.unwrap()), Poll::Pending);
|
||||
|
||||
// Latency has elapsed, so data should be available.
|
||||
MockClock::advance(Duration::from_millis(1));
|
||||
assert_eq!(
|
||||
poll!(io.read(&mut data)).map(|r| r.unwrap()),
|
||||
Poll::Ready(DEFAULT_BUCKET_CAPACITY)
|
||||
);
|
||||
}
|
||||
.block_on()
|
||||
}
|
||||
}
|
||||
10
futures-plex/Cargo.toml
Normal file
10
futures-plex/Cargo.toml
Normal file
@@ -0,0 +1,10 @@
|
||||
[package]
|
||||
name = "futures-plex"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
description = "Port of tokio's `SimplexStream` and `DuplexStream` for the `futures` ecosystem."
|
||||
|
||||
[dependencies]
|
||||
bytes = { version = "1" }
|
||||
futures-io = { version = "0.3" }
|
||||
futures-util = { version = "0.3", default-features = false, features = ["io"] }
|
||||
5
futures-plex/README.md
Normal file
5
futures-plex/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# futures-plex
|
||||
|
||||
Port of tokio's `SimplexStream` and `DuplexStream` for the `futures` ecosystem.
|
||||
|
||||
This crate provides in-memory implementations for `AsyncRead` and `AsyncWrite`.
|
||||
344
futures-plex/src/lib.rs
Normal file
344
futures-plex/src/lib.rs
Normal file
@@ -0,0 +1,344 @@
|
||||
#![doc = include_str!("../README.md")]
|
||||
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{self, Poll, Waker},
|
||||
};
|
||||
|
||||
use bytes::{Buf, BytesMut};
|
||||
use futures_io::{AsyncRead, AsyncWrite};
|
||||
use futures_util::{
|
||||
AsyncReadExt,
|
||||
io::{ReadHalf, WriteHalf},
|
||||
};
|
||||
|
||||
/// A bidirectional pipe to read and write bytes in memory.
|
||||
///
|
||||
/// A pair of `DuplexStream`s are created together, and they act as a "channel"
|
||||
/// that can be used as in-memory IO types. Writing to one of the pairs will
|
||||
/// allow that data to be read from the other, and vice versa.
|
||||
///
|
||||
/// # Closing a `DuplexStream`
|
||||
///
|
||||
/// If one end of the `DuplexStream` channel is dropped, any pending reads on
|
||||
/// the other side will continue to read data until the buffer is drained, then
|
||||
/// they will signal EOF by returning 0 bytes. Any writes to the other side,
|
||||
/// including pending ones (that are waiting for free space in the buffer) will
|
||||
/// return `Err(BrokenPipe)` immediately.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// # async fn ex() -> std::io::Result<()> {
|
||||
/// # use futures_util::{AsyncReadExt, AsyncWriteExt};
|
||||
/// let (mut client, mut server) = futures_plex::duplex(64);
|
||||
///
|
||||
/// client.write_all(b"ping").await?;
|
||||
///
|
||||
/// let mut buf = [0u8; 4];
|
||||
/// server.read_exact(&mut buf).await?;
|
||||
/// assert_eq!(&buf, b"ping");
|
||||
///
|
||||
/// server.write_all(b"pong").await?;
|
||||
///
|
||||
/// client.read_exact(&mut buf).await?;
|
||||
/// assert_eq!(&buf, b"pong");
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct DuplexStream {
|
||||
read: ReadHalf<SimplexStream>,
|
||||
write: WriteHalf<SimplexStream>,
|
||||
}
|
||||
|
||||
/// A unidirectional pipe to read and write bytes in memory.
|
||||
///
|
||||
/// It can be constructed by [`simplex`] function which will create a pair of
|
||||
/// reader and writer or by calling [`SimplexStream::new_unsplit`] that will
|
||||
/// create a handle for both reading and writing.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// # async fn ex() -> std::io::Result<()> {
|
||||
/// # use futures_util::{AsyncReadExt, AsyncWriteExt};
|
||||
/// let (mut receiver, mut sender) = futures_plex::simplex(64);
|
||||
///
|
||||
/// sender.write_all(b"ping").await?;
|
||||
///
|
||||
/// let mut buf = [0u8; 4];
|
||||
/// receiver.read_exact(&mut buf).await?;
|
||||
/// assert_eq!(&buf, b"ping");
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct SimplexStream {
|
||||
/// The buffer storing the bytes written, also read from.
|
||||
///
|
||||
/// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
|
||||
/// functionality already. Additionally, it can try to copy data in the
|
||||
/// same buffer if there read index has advanced far enough.
|
||||
buffer: BytesMut,
|
||||
/// Determines if the write side has been closed.
|
||||
is_closed: bool,
|
||||
/// The maximum amount of bytes that can be written before returning
|
||||
/// `Poll::Pending`.
|
||||
max_buf_size: usize,
|
||||
/// If the `read` side has been polled and is pending, this is the waker
|
||||
/// for that parked task.
|
||||
read_waker: Option<Waker>,
|
||||
/// If the `write` side has filled the `max_buf_size` and returned
|
||||
/// `Poll::Pending`, this is the waker for that parked task.
|
||||
write_waker: Option<Waker>,
|
||||
}
|
||||
|
||||
// ===== impl DuplexStream =====
|
||||
|
||||
/// Create a new pair of `DuplexStream`s that act like a pair of connected
|
||||
/// sockets.
|
||||
///
|
||||
/// The `max_buf_size` argument is the maximum amount of bytes that can be
|
||||
/// written to a side before the write returns `Poll::Pending`.
|
||||
pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
|
||||
let (read_0, write_0) = SimplexStream::new_unsplit(max_buf_size).split();
|
||||
let (read_1, write_1) = SimplexStream::new_unsplit(max_buf_size).split();
|
||||
|
||||
(
|
||||
DuplexStream {
|
||||
read: read_0,
|
||||
write: write_1,
|
||||
},
|
||||
DuplexStream {
|
||||
read: read_1,
|
||||
write: write_0,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
impl AsyncRead for DuplexStream {
|
||||
// Previous rustc required this `self` to be `mut`, even though newer
|
||||
// versions recognize it isn't needed to call `lock()`. So for
|
||||
// compatibility, we include the `mut` and `allow` the lint.
|
||||
//
|
||||
// See https://github.com/rust-lang/rust/issues/73592
|
||||
#[allow(unused_mut)]
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
Pin::new(&mut self.read).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for DuplexStream {
|
||||
#[allow(unused_mut)]
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
Pin::new(&mut self.write).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
Pin::new(&mut self.write).poll_write_vectored(cx, bufs)
|
||||
}
|
||||
|
||||
#[allow(unused_mut)]
|
||||
fn poll_flush(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.write).poll_flush(cx)
|
||||
}
|
||||
|
||||
#[allow(unused_mut)]
|
||||
fn poll_close(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.write).poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
// ===== impl SimplexStream =====
|
||||
|
||||
/// Creates unidirectional buffer that acts like in memory pipe.
|
||||
///
|
||||
/// The `max_buf_size` argument is the maximum amount of bytes that can be
|
||||
/// written to a buffer before the it returns `Poll::Pending`.
|
||||
///
|
||||
/// # Reunite reader and writer
|
||||
///
|
||||
/// The reader and writer half can be unified into a single structure
|
||||
/// of `SimplexStream` that supports both reading and writing or
|
||||
/// the `SimplexStream` can be already created as unified structure
|
||||
/// using [`SimplexStream::new_unsplit()`].
|
||||
///
|
||||
/// ```
|
||||
/// # async fn ex() -> std::io::Result<()> {
|
||||
/// # use futures_util::{AsyncReadExt, AsyncWriteExt};
|
||||
/// let (reader, writer) = futures_plex::simplex(64);
|
||||
/// let mut simplex_stream = reader.reunite(writer).unwrap();
|
||||
/// simplex_stream.write_all(b"hello").await?;
|
||||
///
|
||||
/// let mut buf = [0u8; 5];
|
||||
/// simplex_stream.read_exact(&mut buf).await?;
|
||||
/// assert_eq!(&buf, b"hello");
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn simplex(max_buf_size: usize) -> (ReadHalf<SimplexStream>, WriteHalf<SimplexStream>) {
|
||||
SimplexStream::new_unsplit(max_buf_size).split()
|
||||
}
|
||||
|
||||
impl SimplexStream {
|
||||
/// Creates unidirectional buffer that acts like in memory pipe. To create
|
||||
/// split version with separate reader and writer you can use
|
||||
/// [`simplex`] function.
|
||||
///
|
||||
/// The `max_buf_size` argument is the maximum amount of bytes that can be
|
||||
/// written to a buffer before the it returns `Poll::Pending`.
|
||||
pub fn new_unsplit(max_buf_size: usize) -> SimplexStream {
|
||||
SimplexStream {
|
||||
buffer: BytesMut::new(),
|
||||
is_closed: false,
|
||||
max_buf_size,
|
||||
read_waker: None,
|
||||
write_waker: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn close_write(&mut self) {
|
||||
self.is_closed = true;
|
||||
// needs to notify any readers that no more data will come
|
||||
if let Some(waker) = self.read_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_read_internal(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
if self.buffer.has_remaining() {
|
||||
let len = self.buffer.remaining().min(buf.len());
|
||||
buf[..len].copy_from_slice(&self.buffer[..len]);
|
||||
self.buffer.advance(len);
|
||||
if len > 0 {
|
||||
// The passed `buf` might have been empty, don't wake up if
|
||||
// no bytes have been moved.
|
||||
if let Some(waker) = self.write_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
Poll::Ready(Ok(len))
|
||||
} else if self.is_closed {
|
||||
Poll::Ready(Ok(0))
|
||||
} else {
|
||||
self.read_waker = Some(cx.waker().clone());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write_internal(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
if self.is_closed {
|
||||
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
|
||||
}
|
||||
let avail = self.max_buf_size - self.buffer.len();
|
||||
if avail == 0 {
|
||||
self.write_waker = Some(cx.waker().clone());
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
let len = buf.len().min(avail);
|
||||
self.buffer.extend_from_slice(&buf[..len]);
|
||||
if let Some(waker) = self.read_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
Poll::Ready(Ok(len))
|
||||
}
|
||||
|
||||
fn poll_write_vectored_internal(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
if self.is_closed {
|
||||
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
|
||||
}
|
||||
let avail = self.max_buf_size - self.buffer.len();
|
||||
if avail == 0 {
|
||||
self.write_waker = Some(cx.waker().clone());
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
let mut rem = avail;
|
||||
for buf in bufs {
|
||||
if rem == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
let len = buf.len().min(rem);
|
||||
self.buffer.extend_from_slice(&buf[..len]);
|
||||
rem -= len;
|
||||
}
|
||||
|
||||
if let Some(waker) = self.read_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
Poll::Ready(Ok(avail - rem))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for SimplexStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
self.poll_read_internal(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for SimplexStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
self.poll_write_internal(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
self.poll_write_vectored_internal(cx, bufs)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
mut self: Pin<&mut Self>,
|
||||
_: &mut task::Context<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
self.close_write();
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
@@ -1,2 +1,2 @@
|
||||
imports_granularity = "Crate"
|
||||
wrap_comments = true
|
||||
wrap_comments = true
|
||||
|
||||
Reference in New Issue
Block a user