Initial yamux implementation.

This commit is contained in:
Toralf Wittner
2018-06-14 13:55:52 +02:00
parent 298cda7f4e
commit 038d2c536c
11 changed files with 1393 additions and 0 deletions

9
.editorconfig Normal file
View File

@@ -0,0 +1,9 @@
root = true
[*]
charset=utf-8
end_of_line=lf
indent_size=4
indent_style=space
max_line_length=100

17
Cargo.toml Normal file
View File

@@ -0,0 +1,17 @@
[package]
name = "yamux"
version = "0.1.0"
authors = ["Parity Technologies <admin@parity.io>"]
license = "MIT"
[dependencies]
bytes = "0.4"
futures = "0.1"
log = "0.4"
quick-error = "0.1"
tokio-io = "0.1"
[dev-dependencies]
env_logger = "0.5"
quickcheck = "0.6"
tokio = "0.1"

18
LICENSE Normal file
View File

@@ -0,0 +1,18 @@
Copyright 2018 Parity Technologies (UK) Ltd.
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS
OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

443
src/connection.rs Normal file
View File

@@ -0,0 +1,443 @@
use error::ConnectionError;
use frame::{
codec::FrameCodec,
header::{ACK, ECODE_PROTO, FIN, Header, RST, SYN, Type},
Body,
Data,
Frame,
GoAway,
Ping,
RawFrame,
WindowUpdate
};
use futures::{prelude::*, self, sync::{mpsc, oneshot}};
use std::{
collections::BTreeMap,
sync::{atomic::{AtomicUsize, Ordering}, Arc},
u32,
usize
};
use stream::{Item, Stream, StreamId, Window};
use tokio_io::{codec::Framed, AsyncRead, AsyncWrite};
use Config;
/// Connection mode
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum Mode {
Client,
Server
}
// Commands sent from `Ctrl` to `Connection`.
enum Cmd {
OpenStream(Option<Body>, oneshot::Sender<Stream>)
}
/// `Ctrl` allows controlling some connection aspects, e.g. opening new streams.
#[derive(Clone)]
pub struct Ctrl {
sender: mpsc::Sender<Cmd>
}
impl Ctrl {
fn new(sender: mpsc::Sender<Cmd>) -> Ctrl {
Ctrl { sender }
}
/// Open a new stream optionally sending some initial data to the remote endpoint.
pub fn open_stream(&self, data: Option<Body>) -> impl Future<Item=Stream, Error=ConnectionError> {
let (tx, rx) = oneshot::channel();
self.sender.clone()
.send(Cmd::OpenStream(data, tx))
.map_err(|_| ConnectionError::Closed)
.and_then(move |_| rx.map_err(|_| ConnectionError::Closed))
}
}
// Handle to stream. Used by connection to deliver incoming data.
#[derive(Clone)]
struct StreamInbox {
recv_win: Arc<Window>,
items: mpsc::UnboundedSender<Item>,
ack: bool
}
/// A connection which multiplexes streams to the remote endpoint.
pub struct Connection<T> {
is_dead: bool,
label: &'static str,
mode: Mode,
resource: Framed<T, FrameCodec>,
config: Arc<Config>,
id_counter: AtomicUsize,
streams: BTreeMap<StreamId, StreamInbox>,
from_streams: mpsc::UnboundedReceiver<(StreamId, Item)>,
stream_sender: mpsc::UnboundedSender<(StreamId, Item)>,
from_ctrl: mpsc::Receiver<Cmd>,
ctrl: Ctrl,
pending: Option<RawFrame>
}
impl<T> Connection<T>
where
T: AsyncRead + AsyncWrite
{
/// Create a new connection either in client or server mode.
pub fn new(resource: T, config: Arc<Config>, mode: Mode) -> Self {
let seed = match mode {
Mode::Client => 1,
Mode::Server => 2
};
let (stream_tx, stream_rx) = mpsc::unbounded();
let (ctrl_tx, ctrl_rx) = mpsc::channel(1024);
Connection {
mode,
label: "",
is_dead: false,
resource: resource.framed(FrameCodec::new()),
config,
id_counter: AtomicUsize::new(seed),
streams: BTreeMap::new(),
from_streams: stream_rx,
stream_sender: stream_tx,
from_ctrl: ctrl_rx,
ctrl: Ctrl::new(ctrl_tx),
pending: None
}
}
/// Optionally set a label which shows up in log messages.
pub fn set_label(&mut self, label: &'static str) {
self.label = label
}
/// Get a control handle which allows to open new streams.
pub fn control(&self) -> Ctrl {
self.ctrl.clone()
}
fn open_stream(&mut self, data: Option<Body>) -> Result<(Stream, Frame<Data>), ConnectionError> {
let id = self.next_stream_id()?;
let credit = self.config.receive_window;
let stream = self.new_stream(id, credit);
let mut frame = Frame::data(id, data.unwrap_or(Body::empty()));
frame.header_mut().syn();
Ok((stream, frame))
}
fn on_item(&mut self, item: (StreamId, Item)) -> RawFrame {
let set_ack_flag = self.streams.get_mut(&item.0).map(|inbox| {
let prev = inbox.ack;
inbox.ack = false;
prev
}).unwrap_or(false);
match item.1 {
Item::Data(body) => {
let mut frame = Frame::data(item.0, body);
if set_ack_flag {
frame.header_mut().ack()
}
frame.into_raw()
}
Item::WindowUpdate(n) => {
let mut frame = Frame::window_update(item.0, n);
if set_ack_flag {
frame.header_mut().ack()
}
frame.into_raw()
}
Item::Reset => {
let mut header = Header::data(item.0, 0);
header.rst();
Frame::new(header).into_raw()
}
Item::Finish => {
let mut header = Header::data(item.0, 0);
header.fin();
Frame::new(header).into_raw()
}
}
}
fn on_data(&mut self, frame: Frame<Data>) -> Result<Option<Stream>, Frame<GoAway>> {
let stream_id = frame.header().id();
if frame.header().flags().contains(RST) {
self.on_reset(stream_id);
return Ok(None)
}
let is_finish = frame.header().flags().contains(FIN); // half-close
let body = frame.body().clone();
if frame.header().flags().contains(SYN) { // new stream
if !self.is_valid_remote_id(stream_id, Type::Data) {
warn!("{}invalid stream id {}", self.label, stream_id);
return Err(Frame::go_away(ECODE_PROTO))
}
let credit = self.config.receive_window;
if body.bytes().len() >= credit as usize {
warn!("{}initial data exceeds receive window", self.label);
return Err(Frame::go_away(ECODE_PROTO))
}
if self.streams.contains_key(&stream_id) {
warn!("{}stream {} already exists", self.label, stream_id);
return Err(Frame::go_away(ECODE_PROTO))
}
let stream = self.new_stream(stream_id, credit);
if is_finish {
assert!(self.deliver(stream_id, Item::Finish))
}
if body.bytes().len() > 0 {
assert!(self.deliver(stream_id, Item::Data(body)))
}
return Ok(Some(stream))
}
if !self.deliver(stream_id, Item::Data(body)) {
return Ok(None)
}
if is_finish {
self.on_finish(stream_id)
}
Ok(None)
}
fn on_window_update(&mut self, frame: Frame<WindowUpdate>) -> Result<Option<Stream>, Frame<GoAway>> {
let stream_id = frame.header().id();
if frame.header().flags().contains(RST) { // reset stream
self.on_reset(stream_id);
return Ok(None)
}
let credit = frame.header().credit();
let is_finish = frame.header().flags().contains(FIN); // half-close
if frame.header().flags().contains(SYN) { // new stream
if !self.is_valid_remote_id(stream_id, Type::WindowUpdate) {
warn!("{}invalid stream id {}", self.label, stream_id);
return Err(Frame::go_away(ECODE_PROTO))
}
if self.streams.contains_key(&stream_id) {
warn!("{}stream {} already exists", self.label, stream_id);
return Err(Frame::go_away(ECODE_PROTO))
}
let stream = self.new_stream(stream_id, credit);
if is_finish {
assert!(self.deliver(stream_id, Item::Finish))
}
return Ok(Some(stream))
}
if !self.deliver(stream_id, Item::WindowUpdate(credit)) {
return Ok(None)
}
if is_finish {
self.on_finish(stream_id)
}
Ok(None)
}
fn on_ping(&mut self, frame: Frame<Ping>) -> Result<Option<Frame<Ping>>, ConnectionError> {
let stream_id = frame.header().id();
if frame.header().flags().contains(ACK) { // pong
Ok(None) // TODO
} else {
if self.streams.contains_key(&stream_id) {
let mut hdr = Header::ping(frame.header().nonce());
hdr.ack();
Ok(Some(Frame::new(hdr)))
} else {
debug!("{}received ping for unknown stream {}", self.label, stream_id);
Ok(None)
}
}
}
fn on_go_away(&mut self, frame: Frame<GoAway>) {
info!("{}received go_away frame; error code = {}", self.label, frame.header().error_code());
self.terminate()
}
fn on_reset(&mut self, id: StreamId) {
self.deliver(id, Item::Reset);
}
fn on_finish(&mut self, id: StreamId) {
self.deliver(id, Item::Finish);
}
fn next_stream_id(&self) -> Result<StreamId, ConnectionError> {
if self.id_counter.load(Ordering::SeqCst) >= u32::MAX as usize - 2 {
return Err(ConnectionError::NoMoreStreamIds)
}
let proposed = StreamId::new(self.id_counter.fetch_add(2, Ordering::SeqCst) as u32);
match self.mode {
Mode::Client => assert!(proposed.is_client()),
Mode::Server => assert!(proposed.is_server())
}
Ok(proposed)
}
fn is_valid_remote_id(&self, id: StreamId, ty: Type) -> bool {
match ty {
Type::Ping | Type::GoAway => return id.is_session(),
_ => {}
}
match self.mode {
Mode::Client => id.is_server(),
Mode::Server => id.is_client()
}
}
fn new_stream(&mut self, id: StreamId, recv_window: u32) -> Stream {
let recv_win = Arc::new(Window::new(AtomicUsize::new(recv_window as usize)));
let (tx_stream, rx_stream) = mpsc::unbounded();
let inbox = StreamInbox {
recv_win: recv_win.clone(),
items: tx_stream,
ack: true
};
self.streams.insert(id, inbox);
Stream::new(id, self.config.clone(), self.stream_sender.clone(), rx_stream, recv_win)
}
fn deliver(&mut self, id: StreamId, item: Item) -> bool {
if let Some(ref inbox) = self.streams.get(&id) {
if inbox.items.unbounded_send(item).is_ok() {
return true
}
}
debug!("{}can not deliver; stream {} is gone", self.label, id);
self.streams.remove(&id);
false
}
fn terminate(&mut self) {
debug!("{}terminating connection", self.label);
self.is_dead = true;
self.streams.clear()
}
fn send(&mut self, frame: RawFrame) -> Poll<(), ConnectionError> {
trace!("{}send: {:?}", self.label, frame.header);
match self.resource.start_send(frame) {
Ok(AsyncSink::Ready) => Ok(Async::Ready(())),
Ok(AsyncSink::NotReady(frame)) => {
self.pending = Some(frame);
Ok(Async::NotReady)
}
Err(e) => {
self.terminate();
Err(e.into())
}
}
}
fn flush(&mut self) -> Poll<(), ConnectionError> {
self.resource.poll_complete().map_err(|e| {
self.terminate();
e.into()
})
}
}
impl<T> futures::Stream for Connection<T>
where
T: AsyncRead + AsyncWrite
{
type Item = Stream;
type Error = ConnectionError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
if self.is_dead {
return Ok(Async::Ready(None))
}
// First, check for pending frames we need to send.
if let Some(frame) = self.pending.take() {
try_ready!(self.send(frame))
}
// Check for control commands.
while let Ok(Async::Ready(Some(command))) = self.from_ctrl.poll() {
match command {
Cmd::OpenStream(body, tx) => {
trace!("{}open stream", self.label);
match self.open_stream(body) {
Ok((stream, frame)) => {
let _ = tx.send(stream);
try_ready!(self.send(frame.into_raw()))
}
Err(e) => {
self.terminate();
return Err(e)
}
}
}
}
}
// Check for items streams want to send.
while let Ok(Async::Ready(Some(item))) = self.from_streams.poll() {
let raw_frame = self.on_item(item);
try_ready!(self.send(raw_frame))
}
// Finally, check for incoming data from remote.
loop {
try_ready!(self.flush());
let to_check = self.resource.poll();
match to_check {
Ok(Async::Ready(Some(frame))) => {
trace!("{}recv: {:?}", self.label, frame.header);
match frame.dyn_type() {
Type::Data => {
match self.on_data(Frame::assert(frame)) {
Ok(None) => {}
Ok(Some(stream)) => return Ok(Async::Ready(Some(stream))),
Err(frame) => try_ready!(self.send(frame.into_raw()))
}
}
Type::WindowUpdate => {
match self.on_window_update(Frame::assert(frame)) {
Ok(None) => {}
Ok(Some(stream)) => return Ok(Async::Ready(Some(stream))),
Err(frame) => try_ready!(self.send(frame.into_raw()))
}
}
Type::Ping => {
match self.on_ping(Frame::assert(frame)) {
Ok(None) => {}
Ok(Some(pong)) => try_ready!(self.send(pong.into_raw())),
Err(e) => {
self.terminate();
return Err(e)
}
}
}
Type::GoAway => {
self.on_go_away(Frame::assert(frame));
return Ok(Async::Ready(None))
}
}
}
Ok(Async::Ready(None)) => {
self.terminate();
return Ok(Async::Ready(None))
}
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(e) => {
self.terminate();
return Err(e.into())
}
}
}
}
}

65
src/error.rs Normal file
View File

@@ -0,0 +1,65 @@
use std::io;
use stream::StreamId;
quick_error! {
#[derive(Debug)]
pub enum DecodeError {
Io(e: io::Error) {
display("i/o error: {}", e)
cause(e)
from()
}
Type(t: u8) {
display("unkown type: {}", t)
}
#[doc(hidden)]
__Nonexhaustive
}
}
quick_error! {
#[derive(Debug)]
pub enum StreamError {
StreamClosed(id: StreamId) {
display("stream {} is closed", id)
}
ConnectionClosed {
display("connection of this stream is closed")
}
BodyTooLarge {
display("body size exceeds allowed maximum")
}
#[doc(hidden)]
__Nonexhaustive
}
}
quick_error! {
#[derive(Debug)]
pub enum ConnectionError {
Io(e: io::Error) {
display("i/o error: {}", e)
cause(e)
from()
}
Decode(e: DecodeError) {
display("decode error: {}", e)
cause(e)
from()
}
Protocol(error_code: u32) {
display("protocol error {}", error_code)
}
NoMoreStreamIds {
display("number of stream ids has been exhausted")
}
Closed {
display("connection is closed")
}
#[doc(hidden)]
__Nonexhaustive
}
}

158
src/frame/codec.rs Normal file
View File

@@ -0,0 +1,158 @@
use bytes::{BigEndian, BufMut, ByteOrder, BytesMut};
use error::DecodeError;
use frame::{header::{Flags, Len, RawHeader, Type, Version}, Body, RawFrame};
use std::io;
use stream::StreamId;
use tokio_io::codec::{BytesCodec, Decoder, Encoder};
#[derive(Debug)]
pub struct FrameCodec {
header_codec: HeaderCodec,
body_codec: BytesCodec
}
impl FrameCodec {
pub fn new() -> FrameCodec {
FrameCodec {
header_codec: HeaderCodec::new(),
body_codec: BytesCodec::new()
}
}
}
impl Encoder for FrameCodec {
type Item = RawFrame;
type Error = io::Error;
fn encode(&mut self, frame: Self::Item, bytes: &mut BytesMut) -> Result<(), Self::Error> {
self.header_codec.encode(frame.header, bytes)?;
self.body_codec.encode(frame.body.0, bytes)
}
}
impl Decoder for FrameCodec {
type Item = RawFrame;
type Error = DecodeError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let header =
if let Some(header) = self.header_codec.decode(src)? {
header
} else {
return Ok(None)
};
if header.typ != Type::Data || header.length.0 == 0 {
return Ok(Some(RawFrame { header, body: Body::empty() }))
}
let len = header.length.0 as usize;
if src.len() < len {
return Ok(None)
}
if let Some(b) = self.body_codec.decode(&mut src.split_to(len))? {
Ok(Some(RawFrame { header, body: Body(b.freeze()) }))
} else {
Ok(None)
}
}
}
#[derive(Debug)]
pub struct HeaderCodec(());
impl HeaderCodec {
pub fn new() -> HeaderCodec {
HeaderCodec(())
}
}
impl Encoder for HeaderCodec {
type Item = RawHeader;
type Error = io::Error;
fn encode(&mut self, hdr: Self::Item, bytes: &mut BytesMut) -> Result<(), Self::Error> {
bytes.reserve(12);
bytes.put_u8(hdr.version.0);
bytes.put_u8(hdr.typ as u8);
bytes.put_u16_be(hdr.flags.0);
bytes.put_u32_be(hdr.stream_id.as_u32());
bytes.put_u32_be(hdr.length.0);
Ok(())
}
}
impl Decoder for HeaderCodec {
type Item = RawHeader;
type Error = DecodeError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if src.len() < 12 {
return Ok(None)
}
let src = src.split_to(12);
let header = RawHeader {
version: Version(src[0]),
typ: match src[1] {
0 => Type::Data,
1 => Type::WindowUpdate,
2 => Type::Ping,
3 => Type::GoAway,
t => return Err(DecodeError::Type(t))
},
flags: Flags(BigEndian::read_u16(&src[2..4])),
stream_id: StreamId::new(BigEndian::read_u32(&src[4..8])),
length: Len(BigEndian::read_u32(&src[8..12]))
};
Ok(Some(header))
}
}
#[cfg(test)]
mod tests {
use bytes::Bytes;
use quickcheck::{Arbitrary, Gen, quickcheck};
use super::*;
impl Arbitrary for RawFrame {
fn arbitrary<G: Gen>(g: &mut G) -> Self {
use frame::header::Type::*;
let ty = g.choose(&[Data, WindowUpdate, Ping, GoAway]).unwrap().clone();
let len = g.gen::<u16>() as u32;
let header = RawHeader {
version: Version(g.gen()),
typ: ty,
flags: Flags(g.gen()),
stream_id: StreamId::new(g.gen()),
length: Len(len)
};
let body =
if ty == Type::Data {
let bytes = Bytes::from(vec![0; len as usize]);
Body::from_bytes(bytes).unwrap()
} else {
Body::empty()
};
RawFrame { header, body }
}
}
#[test]
fn frame_identity() {
fn property(f: RawFrame) -> bool {
let mut buf = BytesMut::with_capacity(12 + f.body.bytes().len());
let mut codec = FrameCodec::new();
if codec.encode(f.clone(), &mut buf).is_err() {
return false
}
if let Ok(x) = codec.decode(&mut buf) {
x == Some(f)
} else {
false
}
}
quickcheck(property as fn(RawFrame) -> bool)
}
}

201
src/frame/header.rs Normal file
View File

@@ -0,0 +1,201 @@
use std::marker::PhantomData;
use stream::StreamId;
use super::{Data, WindowUpdate, Ping, GoAway};
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Type {
Data,
WindowUpdate,
Ping,
GoAway
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Version(pub u8);
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Len(pub u32);
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Flags(pub u16);
impl Flags {
pub fn contains(&self, other: Flags) -> bool {
self.0 & other.0 == other.0
}
pub fn and(&self, other: Flags) -> Flags {
Flags(self.0 | other.0)
}
}
/// Protocol error code for use with GoAway frames.
pub const ECODE_PROTO: u32 = 1;
/// Internal error code for use with GoAway frames.
pub const ECODE_INTERNAL: u32 = 2;
pub const SYN: Flags = Flags(1);
pub const ACK: Flags = Flags(2);
pub const FIN: Flags = Flags(4);
pub const RST: Flags = Flags(8);
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RawHeader {
pub version: Version,
pub typ: Type,
pub flags: Flags,
pub stream_id: StreamId,
pub length: Len
}
#[derive(Clone, Debug)]
pub struct Header<T> {
raw_header: RawHeader,
header_type: PhantomData<T>
}
impl<T> Header<T> {
pub(crate) fn assert(raw: RawHeader) -> Self {
Header {
raw_header: raw,
header_type: PhantomData
}
}
pub fn id(&self) -> StreamId {
self.raw_header.stream_id
}
pub fn flags(&self) -> Flags {
self.raw_header.flags
}
pub fn into_raw(self) -> RawHeader {
self.raw_header
}
}
impl Header<Data> {
pub fn data(id: StreamId, len: u32) -> Self {
Header {
raw_header: RawHeader {
version: Version(0),
typ: Type::Data,
flags: Flags(0),
stream_id: id,
length: Len(len)
},
header_type: PhantomData
}
}
pub fn syn(&mut self) {
self.raw_header.flags.0 |= SYN.0
}
pub fn ack(&mut self) {
self.raw_header.flags.0 |= ACK.0
}
pub fn fin(&mut self) {
self.raw_header.flags.0 |= FIN.0
}
pub fn rst(&mut self) {
self.raw_header.flags.0 |= RST.0
}
pub fn len(&self) -> u32 {
self.raw_header.length.0
}
}
impl Header<WindowUpdate> {
pub fn window_update(id: StreamId, credit: u32) -> Self {
Header {
raw_header: RawHeader {
version: Version(0),
typ: Type::WindowUpdate,
flags: Flags(0),
stream_id: id,
length: Len(credit)
},
header_type: PhantomData
}
}
pub fn syn(&mut self) {
self.raw_header.flags.0 |= SYN.0
}
pub fn ack(&mut self) {
self.raw_header.flags.0 |= ACK.0
}
pub fn fin(&mut self) {
self.raw_header.flags.0 |= FIN.0
}
pub fn rst(&mut self) {
self.raw_header.flags.0 |= RST.0
}
pub fn credit(&self) -> u32 {
self.raw_header.length.0
}
}
impl Header<Ping> {
pub fn ping(nonce: u32) -> Self {
Header {
raw_header: RawHeader {
version: Version(0),
typ: Type::Ping,
flags: Flags(0),
stream_id: StreamId::new(0),
length: Len(nonce)
},
header_type: PhantomData
}
}
pub fn syn(&mut self) {
self.raw_header.flags.0 |= SYN.0
}
pub fn ack(&mut self) {
self.raw_header.flags.0 |= ACK.0
}
pub fn nonce(&self) -> u32 {
self.raw_header.length.0
}
}
impl Header<GoAway> {
pub fn go_away(error_code: u32) -> Self {
Header {
raw_header: RawHeader {
version: Version(0),
typ: Type::GoAway,
flags: Flags(0),
stream_id: StreamId::new(0),
length: Len(error_code)
},
header_type: PhantomData
}
}
pub fn error_code(&self) -> u32 {
self.raw_header.length.0
}
}

119
src/frame/mod.rs Normal file
View File

@@ -0,0 +1,119 @@
use std::u32;
use bytes::Bytes;
use self::header::{Header, RawHeader};
use stream::StreamId;
pub mod codec;
pub mod header;
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RawFrame {
pub header: RawHeader,
pub body: Body
}
impl RawFrame {
pub fn dyn_type(&self) -> header::Type {
self.header.typ
}
}
pub enum Data {}
pub enum WindowUpdate {}
pub enum Ping {}
pub enum GoAway {}
#[derive(Clone, Debug)]
pub struct Frame<T> {
header: Header<T>,
body: Body
}
impl<T> Frame<T> {
pub(crate) fn assert(raw: RawFrame) -> Self {
Frame {
header: Header::assert(raw.header),
body: raw.body
}
}
pub fn new(header: Header<T>) -> Frame<T> {
Frame { header, body: Body::empty() }
}
pub fn header(&self) -> &Header<T> {
&self.header
}
pub fn header_mut(&mut self) -> &mut Header<T> {
&mut self.header
}
pub fn into_raw(self) -> RawFrame {
RawFrame {
header: self.header.into_raw(),
body: self.body
}
}
}
impl Frame<Data> {
pub fn data(id: StreamId, b: Body) -> Self {
Frame {
header: Header::data(id, b.0.len() as u32),
body: b
}
}
pub fn body(&self) -> &Body {
&self.body
}
}
impl Frame<WindowUpdate> {
pub fn window_update(id: StreamId, n: u32) -> Self {
Frame {
header: Header::window_update(id, n),
body: Body::empty()
}
}
}
impl Frame<GoAway> {
pub fn go_away(error: u32) -> Self {
Frame {
header: Header::go_away(error),
body: Body::empty()
}
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Body(Bytes);
impl Body {
pub fn empty() -> Body {
Body(Bytes::new())
}
pub fn from_bytes(b: Bytes) -> Option<Body> {
if b.len() < u32::MAX as usize {
Some(Body(b))
} else {
None
}
}
pub fn bytes(&self) -> &Bytes {
&self.0
}
pub fn into_bytes(self) -> Bytes {
self.0
}
}

34
src/lib.rs Normal file
View File

@@ -0,0 +1,34 @@
extern crate bytes;
#[macro_use]
extern crate futures;
#[macro_use]
extern crate log;
#[cfg(test)]
extern crate quickcheck;
#[macro_use]
extern crate quick_error;
extern crate tokio_io;
mod connection;
mod stream;
pub mod error;
pub mod frame;
pub use connection::{Connection, Ctrl, Mode};
pub use frame::Body;
pub use stream::{Stream, StreamId};
#[derive(Debug)]
pub struct Config {
pub receive_window: u32
}
impl Default for Config {
fn default() -> Self {
Config {
receive_window: 256 * 1024
}
}
}

245
src/stream.rs Normal file
View File

@@ -0,0 +1,245 @@
use bytes::Bytes;
use error::StreamError;
use frame::Body;
use futures::{self, prelude::*, sync::mpsc, task::{self, Task}};
use std::{fmt, sync::{atomic::{AtomicUsize, Ordering}, Arc}, u32, usize};
use Config;
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct StreamId(u32);
impl StreamId {
pub(crate) fn new(id: u32) -> StreamId {
StreamId(id)
}
pub fn is_server(&self) -> bool {
self.0 % 2 == 0
}
pub fn is_client(&self) -> bool {
!self.is_server()
}
pub fn is_session(&self) -> bool {
self.0 == 0
}
pub fn as_u32(&self) -> u32 {
self.0
}
}
impl fmt::Display for StreamId {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.0)
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum State {
Open,
SendClosed,
RecvClosed,
Closed
}
#[derive(Debug)]
pub struct Window(AtomicUsize);
impl Window {
pub fn new(n: AtomicUsize) -> Window {
Window(n)
}
pub fn decrement(&self, amount: usize) -> usize {
loop {
let prev = self.0.load(Ordering::SeqCst);
let next = prev.checked_sub(amount).unwrap_or(0);
if self.0.compare_and_swap(prev, next, Ordering::SeqCst) == prev {
return next
}
}
}
pub fn set(&self, val: usize) {
self.0.store(val, Ordering::SeqCst)
}
}
pub enum Item {
Data(Body),
WindowUpdate(u32),
Reset,
Finish
}
pub type Sender = mpsc::UnboundedSender<(StreamId, Item)>;
pub type Receiver = mpsc::UnboundedReceiver<Item>;
pub struct Stream {
id: StreamId,
state: State,
config: Arc<Config>,
recv_window: Arc<Window>,
send_window: u32,
outgoing: Option<Bytes>,
sender: Sender,
receiver: Receiver,
writer_task: Option<Task>
}
impl fmt::Debug for Stream {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Stream {{ id: {}, state: {:?} }}", self.id, self.state)
}
}
impl Stream {
pub(crate) fn new(id: StreamId, c: Arc<Config>, s: Sender, r: Receiver, rw: Arc<Window>) -> Stream {
let send_window = c.receive_window;
Stream {
id,
state: State::Open,
config: c,
recv_window: rw,
send_window,
outgoing: None,
sender: s,
receiver: r,
writer_task: None
}
}
pub fn id(&self) -> StreamId {
self.id
}
pub fn reset(mut self) -> Result<(), StreamError> {
if self.state == State::Closed || self.state == State::SendClosed {
return Err(StreamError::StreamClosed(self.id))
}
self.send_item(Item::Reset)
}
pub fn finish(&mut self) -> Result<(), StreamError> {
if self.state == State::Closed || self.state == State::SendClosed {
return Err(StreamError::StreamClosed(self.id))
}
self.send_item(Item::Finish)?;
if self.state == State::RecvClosed {
self.state = State::Closed
} else {
self.state = State::SendClosed
}
Ok(())
}
fn send_item(&mut self, item: Item) -> Result<(), StreamError> {
if self.sender.unbounded_send((self.id, item)).is_err() {
self.state = State::Closed;
return Err(StreamError::ConnectionClosed)
}
Ok(())
}
}
impl futures::Stream for Stream {
type Item = Bytes;
type Error = StreamError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
if self.state == State::Closed || self.state == State::RecvClosed {
return Ok(Async::Ready(None))
}
match self.receiver.poll() {
Err(()) => {
self.state = State::RecvClosed;
return Err(StreamError::StreamClosed(self.id))
}
Ok(Async::NotReady) => return Ok(Async::NotReady),
Ok(Async::Ready(item)) => match item {
Some(Item::Data(body)) => {
let remaining = self.recv_window.decrement(body.bytes().len());
if remaining == 0 {
let item = Item::WindowUpdate(self.config.receive_window);
self.send_item(item)?;
self.recv_window.set(self.config.receive_window as usize);
}
Ok(Async::Ready(Some(body.into_bytes())))
}
Some(Item::WindowUpdate(n)) => {
self.send_window = self.send_window.checked_add(n).unwrap_or(u32::MAX);
if let Some(writer) = self.writer_task.take() {
writer.notify()
}
Ok(Async::NotReady)
}
Some(Item::Finish) => {
if self.state == State::SendClosed {
self.state = State::Closed
} else {
self.state = State::RecvClosed
}
Ok(Async::Ready(None))
}
Some(Item::Reset) | None => {
self.state = State::Closed;
Ok(Async::Ready(None))
}
}
}
}
}
impl futures::Sink for Stream {
type SinkItem = Bytes;
type SinkError = StreamError;
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
if self.state == State::Closed || self.state == State::SendClosed {
return Err(StreamError::StreamClosed(self.id))
}
if self.outgoing.is_some() {
self.poll_complete()?;
}
if self.outgoing.is_some() {
return Ok(AsyncSink::NotReady(item))
}
self.outgoing = Some(item);
Ok(AsyncSink::Ready)
}
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
if self.state == State::Closed || self.state == State::SendClosed {
return Err(StreamError::StreamClosed(self.id))
}
if self.send_window == 0 {
self.writer_task = Some(task::current());
return Ok(Async::NotReady)
}
if let Some(mut b) = self.outgoing.take() {
if b.len() < self.send_window as usize {
self.send_window -= b.len() as u32;
let body = Body::from_bytes(b).ok_or(StreamError::BodyTooLarge)?;
self.send_item(Item::Data(body))?;
return Ok(Async::Ready(()))
}
let bytes = b.split_to(self.send_window as usize);
self.send_window = 0;
let body = Body::from_bytes(bytes).ok_or(StreamError::BodyTooLarge)?;
self.send_item(Item::Data(body))?;
self.outgoing = Some(b);
self.writer_task = Some(task::current());
Ok(Async::NotReady)
} else {
Ok(Async::Ready(()))
}
}
}

84
tests/smoke.rs Normal file
View File

@@ -0,0 +1,84 @@
#[macro_use]
extern crate log;
extern crate env_logger;
extern crate futures;
extern crate tokio;
extern crate yamux;
use futures::{future::{self, Loop}, prelude::*, stream};
use std::sync::Arc;
use tokio::{net::{TcpListener, TcpStream}, runtime::Runtime};
use yamux::{Body, Config, Connection, Mode};
fn server_conn(addr: &str, cfg: Arc<Config>) -> impl Future<Item=Connection<TcpStream>, Error=()> {
TcpListener::bind(&addr.parse().unwrap())
.unwrap()
.incoming()
.map(move |sock| Connection::new(sock, cfg.clone(), Mode::Server))
.into_future()
.map_err(|(e, _rem)| error!("accept failed: {}", e))
.and_then(|(maybe, _rem)| maybe.ok_or(()))
}
fn client_conn(addr: &str, cfg: Arc<Config>) -> impl Future<Item=Connection<TcpStream>, Error=()> {
let address = addr.parse().unwrap();
TcpStream::connect(&address)
.map_err(|e| error!("connect failed: {}", e))
.map(move |sock| Connection::new(sock, cfg.clone(), Mode::Client))
}
#[test]
fn connect_two_endpoints() {
let _ = env_logger::try_init();
let cfg = Arc::new(Config::default());
let mut rt = Runtime::new().unwrap();
let echo_stream_ids = server_conn("127.0.0.1:12345", cfg.clone())
.and_then(|mut conn| {
conn.set_label("S: ");
conn.for_each(|stream| {
debug!("S: new stream {}", stream.id());
let body = vec![
"Hi client!".as_bytes().into(),
format!("{}", stream.id()).as_bytes().into()
];
stream.send_all(stream::iter_ok(body)).map(|_| ())
.or_else(|e| {
error!("S: stream error: {}", e);
Ok(())
})
})
.map_err(|e| error!("S: connection error: {}", e))
});
let client = client_conn("127.0.0.1:12345", cfg.clone()).and_then(|mut conn| {
conn.set_label("C: ");
let ctrl = conn.control();
let future = conn.for_each(|_stream| Ok(()))
.map_err(|e| error!("C: connection error: {}", e));
tokio::spawn(future);
future::loop_fn((0, ctrl), |(i, ctrl)| {
ctrl.open_stream(Some(Body::from_bytes("Hi server!".as_bytes().into()).unwrap()))
.map_err(|e| error!("C: error opening stream: {}", e))
.and_then(move |stream| {
stream.into_future().map(|(data, _rem)| {
debug!("C: received {:?}", data)
})
.map_err(|(e, _rem)| error!("C: stream error: {}", e))
.and_then(move |()| {
if i == 2 {
debug!("C: done");
Ok(Loop::Break(()))
} else {
Ok(Loop::Continue((i + 1, ctrl)))
}
})
})
})
});
rt.spawn(echo_stream_ids);
rt.block_on(client).unwrap();
}