refactoring client for async

This commit is contained in:
sinuio
2022-05-19 10:59:16 -07:00
parent cc95a5c5b1
commit 858b35ec60
8 changed files with 351 additions and 243 deletions

View File

@@ -16,9 +16,13 @@ name = "tls_aio"
rustversion = { version = "1.0.6", optional = true }
[dependencies]
async-trait = "0.1.53"
futures = "0.3.21"
log = { version = "0.4.4", optional = true }
ring = "0.16.20"
sct = "0.7.0"
tokio = "1.18.2"
tokio-util = "0.7.2"
webpki = { version = "0.22.0", features = ["alloc", "std"] }
[features]

View File

@@ -18,10 +18,12 @@ use super::hs;
use std::convert::TryFrom;
use std::error::Error as StdError;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{fmt, io, mem};
use tokio::io::AsyncWrite;
/// A trait for the ability to store client session data.
/// The keys and values are opaque.
@@ -372,11 +374,12 @@ impl EarlyData {
/// Stub that implements io::Write and dispatches to `write_early_data`.
pub struct WriteEarlyData<'a> {
sess: &'a mut ClientConnection,
fut: Option<futures::future::BoxFuture<'a, io::Result<usize>>>,
}
impl<'a> WriteEarlyData<'a> {
fn new(sess: &'a mut ClientConnection) -> WriteEarlyData<'a> {
WriteEarlyData { sess }
WriteEarlyData { sess, fut: None }
}
/// How many bytes you may send. Writes will become short
@@ -386,19 +389,31 @@ impl<'a> WriteEarlyData<'a> {
}
}
impl<'a> io::Write for WriteEarlyData<'a> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.sess.write_early_data(buf)
impl<'a> AsyncWrite for WriteEarlyData<'a> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
let fut = match self.fut {
Some(fut) => fut,
None => Box::pin(self.sess.write_early_data(buf)),
};
fut.as_mut().poll(cx)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
}
/// This represents a single TLS client connection.
pub struct ClientConnection {
inner: ConnectionCommon<ClientConnectionData>,
inner: ConnectionCommon,
}
impl fmt::Debug for ClientConnection {
@@ -411,11 +426,11 @@ impl ClientConnection {
/// Make a new ClientConnection. `config` controls how
/// we behave in the TLS protocol, `name` is the
/// name of the server we want to talk to.
pub fn new(config: Arc<ClientConfig>, name: ServerName) -> Result<Self, Error> {
Self::new_inner(config, name, Vec::new(), Protocol::Tcp)
pub async fn new(config: Arc<ClientConfig>, name: ServerName) -> Result<Self, Error> {
Self::new_inner(config, name, Vec::new(), Protocol::Tcp).await
}
fn new_inner(
async fn new_inner(
config: Arc<ClientConfig>,
name: ServerName,
extra_exts: Vec<ClientExtension>,
@@ -430,7 +445,7 @@ impl ClientConnection {
data: &mut data,
};
let state = hs::start_handshake(name, extra_exts, config, &mut cx)?;
let state = hs::start_handshake(name, extra_exts, config, &mut cx).await?;
let inner = ConnectionCommon::new(state, data, common_state);
Ok(Self { inner })
@@ -471,17 +486,23 @@ impl ClientConnection {
self.inner.data.early_data.is_accepted()
}
fn write_early_data(&mut self, data: &[u8]) -> io::Result<usize> {
self.inner
.data
.early_data
.check_write(data.len())
.map(|sz| self.inner.common_state.send_early_plaintext(&data[..sz]))
async fn write_early_data(&mut self, data: &[u8]) -> io::Result<usize> {
let sz = self.inner.data.early_data.check_write(data.len());
if let Ok(sz) = sz {
Ok(self
.inner
.common_state
.send_early_plaintext(&data[..sz])
.await)
} else {
sz
}
}
}
impl Deref for ClientConnection {
type Target = ConnectionCommon<ClientConnectionData>;
type Target = ConnectionCommon;
fn deref(&self) -> &Self::Target {
&self.inner

View File

@@ -33,11 +33,12 @@ use crate::client::client_conn::ClientConnectionData;
use crate::client::common::ClientHelloDetails;
use crate::client::{tls13, ClientConfig, ServerName};
use async_trait::async_trait;
use std::sync::Arc;
pub(super) type NextState = Box<dyn State<ClientConnectionData>>;
pub(super) type NextStateOrError = Result<NextState, Error>;
pub(super) type ClientContext<'a> = crate::conn::Context<'a, ClientConnectionData>;
pub(super) type ClientContext<'a> = crate::conn::Context<'a>;
fn find_session(
server_name: &ServerName,
@@ -68,7 +69,7 @@ fn find_session(
.and_then(|resuming| Some(resuming))
}
pub(super) fn start_handshake(
pub(super) async fn start_handshake(
server_name: ServerName,
extra_exts: Vec<ClientExtension>,
config: Arc<ClientConfig>,
@@ -132,7 +133,8 @@ pub(super) fn start_handshake(
extra_exts,
may_send_sct_list,
None,
))
)
.await)
}
struct ExpectServerHello {
@@ -155,7 +157,7 @@ struct ExpectServerHelloOrHelloRetryRequest {
extra_exts: Vec<ClientExtension>,
}
fn emit_client_hello_for_retry(
async fn emit_client_hello_for_retry(
config: Arc<ClientConfig>,
cx: &mut ClientContext<'_>,
resuming_session: Option<persist::Retrieved<persist::ClientSessionValue>>,
@@ -403,8 +405,13 @@ pub(super) fn sct_list_is_invalid(scts: &SCTList) -> bool {
scts.is_empty() || scts.iter().any(|sct| sct.0.is_empty())
}
#[async_trait]
impl State<ClientConnectionData> for ExpectServerHello {
fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> NextStateOrError {
async fn handle(
mut self: Box<Self>,
cx: &mut ClientContext<'_>,
m: Message,
) -> NextStateOrError {
let server_hello =
require_handshake_msg!(m, HandshakeType::ServerHello, HandshakePayload::ServerHello)?;
trace!("We got ServerHello {:#?}", server_hello);
@@ -552,6 +559,7 @@ impl State<ClientConnectionData> for ExpectServerHello {
self.offered_key_share.unwrap(),
self.sent_tls13_fake_ccs,
)
.await
}
#[cfg(feature = "tls12")]
SupportedCipherSuite::Tls12(suite) => {
@@ -571,6 +579,7 @@ impl State<ClientConnectionData> for ExpectServerHello {
transcript,
}
.handle_server_hello(cx, suite, server_hello, tls13_supported)
.await
}
}
}
@@ -581,7 +590,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
Box::new(self.next)
}
fn handle_hello_retry_request(
async fn handle_hello_retry_request(
self,
cx: &mut ClientContext<'_>,
m: Message,
@@ -707,21 +716,23 @@ impl ExpectServerHelloOrHelloRetryRequest {
self.extra_exts,
may_send_sct_list,
Some(cs),
))
)
.await)
}
}
#[async_trait]
impl State<ClientConnectionData> for ExpectServerHelloOrHelloRetryRequest {
fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> NextStateOrError {
async fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> NextStateOrError {
match m.payload {
MessagePayload::Handshake(HandshakeMessagePayload {
payload: HandshakePayload::ServerHello(..),
..
}) => self.into_expect_server_hello().handle(cx, m),
}) => self.into_expect_server_hello().handle(cx, m).await,
MessagePayload::Handshake(HandshakeMessagePayload {
payload: HandshakePayload::HelloRetryRequest(..),
..
}) => self.handle_hello_retry_request(cx, m),
}) => self.handle_hello_retry_request(cx, m).await,
payload => Err(inappropriate_handshake_message(
&payload,
&[ContentType::Handshake],

View File

@@ -29,6 +29,7 @@ use crate::client::{hs, ClientConfig, ServerName};
use ring::agreement::PublicKey;
use ring::constant_time;
use async_trait::async_trait;
use std::sync::Arc;
pub(super) use server_hello::CompleteServerHelloHandling;
@@ -50,16 +51,14 @@ mod server_hello {
}
impl CompleteServerHelloHandling {
pub(in crate::client) fn handle_server_hello(
pub(in crate::client) async fn handle_server_hello(
mut self,
cx: &mut ClientContext,
cx: &mut ClientContext<'_>,
suite: &'static Tls12CipherSuite,
server_hello: &ServerHelloPayload,
tls13_supported: bool,
) -> hs::NextStateOrError {
server_hello
.random
.write_slice(&mut self.randoms.server);
server_hello.random.write_slice(&mut self.randoms.server);
// Look for TLS1.3 downgrade signal in server random
// both the server random and TLS12_DOWNGRADE_SENTINEL are
@@ -132,8 +131,7 @@ mod server_hello {
&secrets.randoms.client,
&secrets.master_secret,
);
cx.common
.start_encryption_tls12(&secrets, Side::Client);
cx.common.start_encryption_tls12(&secrets, Side::Client);
// Since we're resuming, we verified the certificate and
// proof of possession in the prior session.
@@ -203,8 +201,9 @@ struct ExpectCertificate {
server_cert_sct_list: Option<SCTList>,
}
#[async_trait]
impl State<ClientConnectionData> for ExpectCertificate {
fn handle(
async fn handle(
mut self: Box<Self>,
_cx: &mut ClientContext<'_>,
m: Message,
@@ -264,46 +263,57 @@ struct ExpectCertificateStatusOrServerKx {
must_issue_new_ticket: bool,
}
#[async_trait]
impl State<ClientConnectionData> for ExpectCertificateStatusOrServerKx {
fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
async fn handle(
self: Box<Self>,
cx: &mut ClientContext<'_>,
m: Message,
) -> hs::NextStateOrError {
match m.payload {
MessagePayload::Handshake(HandshakeMessagePayload {
payload: HandshakePayload::ServerKeyExchange(..),
..
}) => Box::new(ExpectServerKx {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
using_ems: self.using_ems,
transcript: self.transcript,
suite: self.suite,
server_cert: ServerCertDetails::new(
self.server_cert_chain,
vec![],
self.server_cert_sct_list,
),
must_issue_new_ticket: self.must_issue_new_ticket,
})
.handle(cx, m),
}) => {
Box::new(ExpectServerKx {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
using_ems: self.using_ems,
transcript: self.transcript,
suite: self.suite,
server_cert: ServerCertDetails::new(
self.server_cert_chain,
vec![],
self.server_cert_sct_list,
),
must_issue_new_ticket: self.must_issue_new_ticket,
})
.handle(cx, m)
.await
}
MessagePayload::Handshake(HandshakeMessagePayload {
payload: HandshakePayload::CertificateStatus(..),
..
}) => Box::new(ExpectCertificateStatus {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
using_ems: self.using_ems,
transcript: self.transcript,
suite: self.suite,
server_cert_sct_list: self.server_cert_sct_list,
server_cert_chain: self.server_cert_chain,
must_issue_new_ticket: self.must_issue_new_ticket,
})
.handle(cx, m),
}) => {
Box::new(ExpectCertificateStatus {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
using_ems: self.using_ems,
transcript: self.transcript,
suite: self.suite,
server_cert_sct_list: self.server_cert_sct_list,
server_cert_chain: self.server_cert_chain,
must_issue_new_ticket: self.must_issue_new_ticket,
})
.handle(cx, m)
.await
}
payload => Err(inappropriate_handshake_message(
&payload,
&[ContentType::Handshake],
@@ -330,8 +340,9 @@ struct ExpectCertificateStatus {
must_issue_new_ticket: bool,
}
#[async_trait]
impl State<ClientConnectionData> for ExpectCertificateStatus {
fn handle(
async fn handle(
mut self: Box<Self>,
_cx: &mut ClientContext<'_>,
m: Message,
@@ -383,8 +394,13 @@ struct ExpectServerKx {
must_issue_new_ticket: bool,
}
#[async_trait]
impl State<ClientConnectionData> for ExpectServerKx {
fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
async fn handle(
mut self: Box<Self>,
cx: &mut ClientContext<'_>,
m: Message,
) -> hs::NextStateOrError {
let opaque_kx = require_handshake_msg!(
m,
HandshakeType::ServerKeyExchange,
@@ -392,13 +408,10 @@ impl State<ClientConnectionData> for ExpectServerKx {
)?;
self.transcript.add_message(&m);
let ecdhe = opaque_kx
.unwrap_given_kxa(&self.suite.kx)
.ok_or_else(|| {
cx.common
.send_fatal_alert(AlertDescription::DecodeError);
Error::CorruptMessagePayload(ContentType::Handshake)
})?;
let ecdhe = opaque_kx.unwrap_given_kxa(&self.suite.kx).ok_or_else(|| {
cx.common.send_fatal_alert(AlertDescription::DecodeError);
Error::CorruptMessagePayload(ContentType::Handshake)
})?;
// Save the signature and signed parameters for later verification.
let mut kx_params = Vec::new();
@@ -548,8 +561,13 @@ struct ExpectServerDoneOrCertReq {
must_issue_new_ticket: bool,
}
#[async_trait]
impl State<ClientConnectionData> for ExpectServerDoneOrCertReq {
fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
async fn handle(
mut self: Box<Self>,
cx: &mut ClientContext<'_>,
m: Message,
) -> hs::NextStateOrError {
if matches!(
m.payload,
MessagePayload::Handshake(HandshakeMessagePayload {
@@ -571,6 +589,7 @@ impl State<ClientConnectionData> for ExpectServerDoneOrCertReq {
must_issue_new_ticket: self.must_issue_new_ticket,
})
.handle(cx, m)
.await
} else {
self.transcript.abandon_client_auth();
@@ -589,6 +608,7 @@ impl State<ClientConnectionData> for ExpectServerDoneOrCertReq {
must_issue_new_ticket: self.must_issue_new_ticket,
})
.handle(cx, m)
.await
}
}
}
@@ -607,8 +627,9 @@ struct ExpectCertificateRequest {
must_issue_new_ticket: bool,
}
#[async_trait]
impl State<ClientConnectionData> for ExpectCertificateRequest {
fn handle(
async fn handle(
mut self: Box<Self>,
_cx: &mut ClientContext<'_>,
m: Message,
@@ -629,9 +650,7 @@ impl State<ClientConnectionData> for ExpectCertificateRequest {
const NO_CONTEXT: Option<Vec<u8>> = None; // TLS 1.2 doesn't use a context.
let client_auth = ClientAuthDetails::resolve(
self.config
.client_auth_cert_resolver
.as_ref(),
self.config.client_auth_cert_resolver.as_ref(),
Some(&certreq.canames),
&certreq.sigschemes,
NO_CONTEXT,
@@ -669,8 +688,13 @@ struct ExpectServerDone {
must_issue_new_ticket: bool,
}
#[async_trait]
impl State<ClientConnectionData> for ExpectServerDone {
fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
async fn handle(
self: Box<Self>,
cx: &mut ClientContext<'_>,
m: Message,
) -> hs::NextStateOrError {
match m.payload {
MessagePayload::Handshake(HandshakeMessagePayload {
payload: HandshakePayload::ServerHelloDone,
@@ -778,9 +802,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
let mut transcript = st.transcript;
emit_clientkx(&mut transcript, cx.common, &kx.pubkey);
// nb. EMS handshake hash only runs up to ClientKeyExchange.
let ems_seed = st
.using_ems
.then(|| transcript.get_current_hash());
let ems_seed = st.using_ems.then(|| transcript.get_current_hash());
// 5c.
if let Some(ClientAuthDetails::Verify { signer, .. }) = &st.client_auth {
@@ -804,11 +826,8 @@ impl State<ClientConnectionData> for ExpectServerDone {
&secrets.randoms.client,
&secrets.master_secret,
);
cx.common
.start_encryption_tls12(&secrets, Side::Client);
cx.common
.record_layer
.start_encrypting();
cx.common.start_encryption_tls12(&secrets, Side::Client);
cx.common.record_layer.start_encrypting();
// 6.
emit_finished(&secrets, &mut transcript, cx.common);
@@ -857,8 +876,9 @@ struct ExpectNewTicket {
sig_verified: verify::HandshakeSignatureValid,
}
#[async_trait]
impl State<ClientConnectionData> for ExpectNewTicket {
fn handle(
async fn handle(
mut self: Box<Self>,
_cx: &mut ClientContext<'_>,
m: Message,
@@ -902,8 +922,13 @@ struct ExpectCcs {
sig_verified: verify::HandshakeSignatureValid,
}
#[async_trait]
impl State<ClientConnectionData> for ExpectCcs {
fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
async fn handle(
self: Box<Self>,
cx: &mut ClientContext<'_>,
m: Message,
) -> hs::NextStateOrError {
match m.payload {
MessagePayload::ChangeCipherSpec(..) => {}
payload => {
@@ -918,9 +943,7 @@ impl State<ClientConnectionData> for ExpectCcs {
cx.common.check_aligned_handshake()?;
// nb. msgs layer validates trivial contents of CCS
cx.common
.record_layer
.start_decrypting();
cx.common.record_layer.start_decrypting();
Ok(Box::new(ExpectFinished {
config: self.config,
@@ -987,10 +1010,7 @@ impl ExpectFinished {
self.session_id,
ticket,
self.secrets.get_master_secret(),
cx.common
.peer_certificates
.clone()
.unwrap_or_default(),
cx.common.peer_certificates.clone().unwrap_or_default(),
time_now,
lifetime,
self.using_ems,
@@ -1009,8 +1029,13 @@ impl ExpectFinished {
}
}
#[async_trait]
impl State<ClientConnectionData> for ExpectFinished {
fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
async fn handle(
self: Box<Self>,
cx: &mut ClientContext<'_>,
m: Message,
) -> hs::NextStateOrError {
let mut st = *self;
let finished =
require_handshake_msg!(m, HandshakeType::Finished, HandshakePayload::Finished)?;
@@ -1026,8 +1051,7 @@ impl State<ClientConnectionData> for ExpectFinished {
let _fin_verified =
constant_time::verify_slices_are_equal(&expect_verify_data, &finished.0)
.map_err(|_| {
cx.common
.send_fatal_alert(AlertDescription::DecryptError);
cx.common.send_fatal_alert(AlertDescription::DecryptError);
Error::DecryptError
})
.map(|_| verify::FinishedMessageVerified::assertion())?;
@@ -1039,9 +1063,7 @@ impl State<ClientConnectionData> for ExpectFinished {
if st.resuming {
emit_ccs(cx.common);
cx.common
.record_layer
.start_encrypting();
cx.common.record_layer.start_encrypting();
emit_finished(&st.secrets, &mut st.transcript, cx.common);
}
@@ -1063,12 +1085,15 @@ struct ExpectTraffic {
_fin_verified: verify::FinishedMessageVerified,
}
#[async_trait]
impl State<ClientConnectionData> for ExpectTraffic {
fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
async fn handle(
self: Box<Self>,
cx: &mut ClientContext<'_>,
m: Message,
) -> hs::NextStateOrError {
match m.payload {
MessagePayload::ApplicationData(payload) => cx
.common
.take_received_plaintext(payload),
MessagePayload::ApplicationData(payload) => cx.common.take_received_plaintext(payload),
payload => {
return Err(inappropriate_message(
&payload,
@@ -1085,8 +1110,7 @@ impl State<ClientConnectionData> for ExpectTraffic {
label: &[u8],
context: Option<&[u8]>,
) -> Result<(), Error> {
self.secrets
.export_keying_material(output, label, context);
self.secrets.export_keying_material(output, label, context);
Ok(())
}
}

View File

@@ -38,6 +38,7 @@ use crate::ticketer::TimeBase;
use ring::constant_time;
use crate::sign::{CertifiedKey, Signer};
use async_trait::async_trait;
use std::sync::Arc;
// Extensions we expect in plaintext in the ServerHello.
@@ -56,9 +57,9 @@ static DISALLOWED_TLS13_EXTS: &[ExtensionType] = &[
ExtensionType::ExtendedMasterSecret,
];
pub(super) fn handle_server_hello(
pub(super) async fn handle_server_hello(
config: Arc<ClientConfig>,
cx: &mut ClientContext,
cx: &mut ClientContext<'_>,
server_hello: &ServerHelloPayload,
mut resuming_session: Option<persist::Tls13ClientSessionValue>,
server_name: ServerName,
@@ -351,8 +352,13 @@ struct ExpectEncryptedExtensions {
hello: ClientHelloDetails,
}
#[async_trait]
impl State<ClientConnectionData> for ExpectEncryptedExtensions {
fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
async fn handle(
mut self: Box<Self>,
cx: &mut ClientContext<'_>,
m: Message,
) -> hs::NextStateOrError {
let exts = require_handshake_msg!(
m,
HandshakeType::EncryptedExtensions,
@@ -427,36 +433,47 @@ struct ExpectCertificateOrCertReq {
may_send_sct_list: bool,
}
#[async_trait]
impl State<ClientConnectionData> for ExpectCertificateOrCertReq {
fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
async fn handle(
self: Box<Self>,
cx: &mut ClientContext<'_>,
m: Message,
) -> hs::NextStateOrError {
match m.payload {
MessagePayload::Handshake(HandshakeMessagePayload {
payload: HandshakePayload::CertificateTLS13(..),
..
}) => Box::new(ExpectCertificate {
config: self.config,
server_name: self.server_name,
randoms: self.randoms,
suite: self.suite,
transcript: self.transcript,
key_schedule: self.key_schedule,
may_send_sct_list: self.may_send_sct_list,
client_auth: None,
})
.handle(cx, m),
}) => {
Box::new(ExpectCertificate {
config: self.config,
server_name: self.server_name,
randoms: self.randoms,
suite: self.suite,
transcript: self.transcript,
key_schedule: self.key_schedule,
may_send_sct_list: self.may_send_sct_list,
client_auth: None,
})
.handle(cx, m)
.await
}
MessagePayload::Handshake(HandshakeMessagePayload {
payload: HandshakePayload::CertificateRequestTLS13(..),
..
}) => Box::new(ExpectCertificateRequest {
config: self.config,
server_name: self.server_name,
randoms: self.randoms,
suite: self.suite,
transcript: self.transcript,
key_schedule: self.key_schedule,
may_send_sct_list: self.may_send_sct_list,
})
.handle(cx, m),
}) => {
Box::new(ExpectCertificateRequest {
config: self.config,
server_name: self.server_name,
randoms: self.randoms,
suite: self.suite,
transcript: self.transcript,
key_schedule: self.key_schedule,
may_send_sct_list: self.may_send_sct_list,
})
.handle(cx, m)
.await
}
payload => Err(inappropriate_handshake_message(
&payload,
&[ContentType::Handshake],
@@ -482,8 +499,13 @@ struct ExpectCertificateRequest {
may_send_sct_list: bool,
}
#[async_trait]
impl State<ClientConnectionData> for ExpectCertificateRequest {
fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
async fn handle(
mut self: Box<Self>,
cx: &mut ClientContext<'_>,
m: Message,
) -> hs::NextStateOrError {
let certreq = &require_handshake_msg!(
m,
HandshakeType::CertificateRequest,
@@ -551,8 +573,13 @@ struct ExpectCertificate {
client_auth: Option<ClientAuthDetails>,
}
#[async_trait]
impl State<ClientConnectionData> for ExpectCertificate {
fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
async fn handle(
mut self: Box<Self>,
cx: &mut ClientContext<'_>,
m: Message,
) -> hs::NextStateOrError {
let cert_chain = require_handshake_msg!(
m,
HandshakeType::Certificate,
@@ -621,8 +648,13 @@ struct ExpectCertificateVerify {
client_auth: Option<ClientAuthDetails>,
}
#[async_trait]
impl State<ClientConnectionData> for ExpectCertificateVerify {
fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
async fn handle(
mut self: Box<Self>,
cx: &mut ClientContext<'_>,
m: Message,
) -> hs::NextStateOrError {
let cert_verify = require_handshake_msg!(
m,
HandshakeType::CertificateVerify,
@@ -780,8 +812,13 @@ struct ExpectFinished {
sig_verified: verify::HandshakeSignatureValid,
}
#[async_trait]
impl State<ClientConnectionData> for ExpectFinished {
fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
async fn handle(
self: Box<Self>,
cx: &mut ClientContext<'_>,
m: Message,
) -> hs::NextStateOrError {
let mut st = *self;
let finished =
require_handshake_msg!(m, HandshakeType::Finished, HandshakePayload::Finished)?;
@@ -892,7 +929,7 @@ struct ExpectTraffic {
impl ExpectTraffic {
#[allow(clippy::unnecessary_wraps)]
fn handle_new_ticket_tls13(
async fn handle_new_ticket_tls13(
&mut self,
cx: &mut ClientContext<'_>,
nst: &NewSessionTicketPayloadTLS13,
@@ -944,7 +981,7 @@ impl ExpectTraffic {
Ok(())
}
fn handle_key_update(
async fn handle_key_update(
&mut self,
common: &mut CommonState,
kur: &KeyUpdateRequest,
@@ -973,18 +1010,23 @@ impl ExpectTraffic {
}
}
#[async_trait]
impl State<ClientConnectionData> for ExpectTraffic {
fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
async fn handle(
mut self: Box<Self>,
cx: &mut ClientContext<'_>,
m: Message,
) -> hs::NextStateOrError {
match m.payload {
MessagePayload::ApplicationData(payload) => cx.common.take_received_plaintext(payload),
MessagePayload::Handshake(HandshakeMessagePayload {
payload: HandshakePayload::NewSessionTicketTLS13(ref new_ticket),
..
}) => self.handle_new_ticket_tls13(cx, new_ticket)?,
}) => self.handle_new_ticket_tls13(cx, new_ticket).await?,
MessagePayload::Handshake(HandshakeMessagePayload {
payload: HandshakePayload::KeyUpdate(ref key_update),
..
}) => self.handle_key_update(cx.common, key_update)?,
}) => self.handle_key_update(cx.common, key_update).await?,
payload => {
return Err(inappropriate_handshake_message(
&payload,

View File

@@ -1,3 +1,4 @@
use crate::client::ClientConnectionData;
use crate::error::Error;
use crate::key;
#[cfg(feature = "logging")]
@@ -19,11 +20,16 @@ use crate::suites::SupportedCipherSuite;
use crate::tls12::ConnectionSecrets;
use crate::vecbuf::ChunkVecBuffer;
use async_trait::async_trait;
use std::collections::VecDeque;
use std::convert::TryFrom;
use std::io;
use std::mem;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task;
use tokio::io::AsyncWrite;
use tokio_util::sync::ReusableBoxFuture;
/// Values of this structure are returned from [`Connection::process_new_packets`]
/// and tell the caller the current I/O state of the TLS connection.
@@ -150,36 +156,10 @@ impl<'a> io::Read for Reader<'a> {
Ok(())
}
}
/// Internal trait implemented by the [`ServerConnection`]/[`ClientConnection`]
/// allowing them to be the subject of a [`Writer`].
pub trait PlaintextSink {
fn write(&mut self, buf: &[u8]) -> io::Result<usize>;
fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize>;
fn flush(&mut self) -> io::Result<()>;
}
impl<T> PlaintextSink for ConnectionCommon<T> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
Ok(self.send_some_plaintext(buf))
}
fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
let mut sz = 0;
for buf in bufs {
sz += self.send_some_plaintext(buf);
}
Ok(sz)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
/// A structure that implements [`std::io::Write`] for writing plaintext.
/// A structure that implements [`tokio::io::AsyncWrite`] for writing plaintext.
pub struct Writer<'a> {
sink: &'a mut dyn PlaintextSink,
con: &'a mut ConnectionCommon,
fut: Option<ReusableBoxFuture<'a, usize>>,
}
impl<'a> Writer<'a> {
@@ -188,12 +168,12 @@ impl<'a> Writer<'a> {
/// This is not an external interface. Get one of these objects
/// from [`Connection::writer`].
#[doc(hidden)]
pub fn new(sink: &'a mut dyn PlaintextSink) -> Writer<'a> {
Writer { sink }
pub fn new(con: &'a mut ConnectionCommon) -> Writer<'a> {
Writer { con, fut: None }
}
}
impl<'a> io::Write for Writer<'a> {
impl<'a> AsyncWrite for Writer<'a> {
/// Send the plaintext `buf` to the peer, encrypting
/// and authenticating it. Once this function succeeds
/// you should call [`CommonState::write_tls`] which will output the
@@ -203,16 +183,33 @@ impl<'a> io::Write for Writer<'a> {
/// TLS handshake completes, and sends it as soon
/// as it can. See [`CommonState::set_buffer_limit`] to control
/// the size of this buffer.
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.sink.write(buf)
fn poll_write(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
let fut = match self.fut.as_mut() {
Some(fut) => fut,
None => {
let fut = self.con.send_some_plaintext(buf);
self.fut.get_or_insert(ReusableBoxFuture::new(fut))
}
};
match fut.poll(cx) {
task::Poll::Ready(sz) => task::Poll::Ready(Ok(sz)),
task::Poll::Pending => task::Poll::Pending,
}
}
fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
self.sink.write_vectored(bufs)
fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
task::Poll::Ready(Ok(()))
}
fn flush(&mut self) -> io::Result<()> {
self.sink.flush()
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> task::Poll<Result<(), io::Error>> {
task::Poll::Ready(Ok(()))
}
}
@@ -255,16 +252,20 @@ enum Limit {
}
/// Interface shared by client and server connections.
pub struct ConnectionCommon<Data> {
state: Result<Box<dyn State<Data>>, Error>,
pub(crate) data: Data,
pub struct ConnectionCommon {
state: Result<Box<dyn State<ClientConnectionData>>, Error>,
pub(crate) data: ClientConnectionData,
pub(crate) common_state: CommonState,
message_deframer: MessageDeframer,
handshake_joiner: HandshakeJoiner,
}
impl<Data> ConnectionCommon<Data> {
pub(crate) fn new(state: Box<dyn State<Data>>, data: Data, common_state: CommonState) -> Self {
impl ConnectionCommon {
pub(crate) fn new(
state: Box<dyn State<ClientConnectionData>>,
data: ClientConnectionData,
common_state: CommonState,
) -> Self {
Self {
state: Ok(state),
data,
@@ -320,7 +321,7 @@ impl<Data> ConnectionCommon<Data> {
/// [`write_tls`]: CommonState::write_tls
/// [`read_tls`]: ConnectionCommon::read_tls
/// [`process_new_packets`]: ConnectionCommon::process_new_packets
pub fn complete_io<T>(&mut self, io: &mut T) -> Result<(usize, usize), io::Error>
pub async fn complete_io<T>(&mut self, io: &mut T) -> Result<(usize, usize), io::Error>
where
Self: Sized,
T: io::Read + io::Write,
@@ -346,7 +347,7 @@ impl<Data> ConnectionCommon<Data> {
}
}
match self.process_new_packets() {
match self.process_new_packets().await {
Ok(_) => {}
Err(e) => {
// In case we have an alert to send describing this error,
@@ -396,15 +397,15 @@ impl<Data> ConnectionCommon<Data> {
Ok(self.handshake_joiner.frames.pop_front())
}
pub(crate) fn replace_state(&mut self, new: Box<dyn State<Data>>) {
pub(crate) fn replace_state(&mut self, new: Box<dyn State<ClientConnectionData>>) {
self.state = Ok(new);
}
fn process_msg(
async fn process_msg(
&mut self,
msg: OpaqueMessage,
state: Box<dyn State<Data>>,
) -> Result<Box<dyn State<Data>>, Error> {
state: Box<dyn State<ClientConnectionData>>,
) -> Result<Box<dyn State<ClientConnectionData>>, Error> {
// Drop CCS messages during handshake in TLS1.3
if msg.typ == ContentType::ChangeCipherSpec
&& !self.common_state.may_receive_application_data
@@ -430,7 +431,7 @@ impl<Data> ConnectionCommon<Data> {
// Decrypt if demanded by current state.
let msg = match self.common_state.record_layer.is_decrypting() {
true => match self.common_state.decrypt_incoming(msg) {
true => match self.common_state.decrypt_incoming(msg).await {
Ok(None) => {
// message dropped
return Ok(state);
@@ -454,7 +455,7 @@ impl<Data> ConnectionCommon<Data> {
.send_fatal_alert(AlertDescription::DecodeError);
Error::CorruptMessagePayload(ContentType::Handshake)
})?;
return self.process_new_handshake_messages(state);
return self.process_new_handshake_messages(state).await;
}
// Now we can fully parse the message payload.
@@ -468,6 +469,7 @@ impl<Data> ConnectionCommon<Data> {
self.common_state
.process_main_protocol(msg, state, &mut self.data)
.await
}
/// Processes any new packets read by a previous call to
@@ -488,7 +490,7 @@ impl<Data> ConnectionCommon<Data> {
///
/// [`read_tls`]: Connection::read_tls
/// [`process_new_packets`]: Connection::process_new_packets
pub fn process_new_packets(&mut self) -> Result<IoState, Error> {
pub async fn process_new_packets(&mut self) -> Result<IoState, Error> {
let mut state = match mem::replace(&mut self.state, Err(Error::HandshakeNotComplete)) {
Ok(state) => state,
Err(e) => {
@@ -502,7 +504,7 @@ impl<Data> ConnectionCommon<Data> {
}
while let Some(msg) = self.message_deframer.frames.pop_front() {
match self.process_msg(msg, state) {
match self.process_msg(msg, state).await {
Ok(new) => state = new,
Err(e) => {
self.state = Err(e.clone());
@@ -515,25 +517,26 @@ impl<Data> ConnectionCommon<Data> {
Ok(self.common_state.current_io_state())
}
fn process_new_handshake_messages(
async fn process_new_handshake_messages(
&mut self,
mut state: Box<dyn State<Data>>,
) -> Result<Box<dyn State<Data>>, Error> {
mut state: Box<dyn State<ClientConnectionData>>,
) -> Result<Box<dyn State<ClientConnectionData>>, Error> {
self.common_state.aligned_handshake = self.handshake_joiner.is_empty();
while let Some(msg) = self.handshake_joiner.frames.pop_front() {
state = self
.common_state
.process_main_protocol(msg, state, &mut self.data)?;
.process_main_protocol(msg, state, &mut self.data)
.await?;
}
Ok(state)
}
pub(crate) fn send_some_plaintext(&mut self, buf: &[u8]) -> usize {
pub(crate) async fn send_some_plaintext(&mut self, buf: &[u8]) -> usize {
if let Ok(st) = &mut self.state {
st.perhaps_write_key_update(&mut self.common_state);
}
self.common_state.send_some_plaintext(buf)
self.common_state.send_some_plaintext(buf).await
}
/// Read TLS content from `rd`. This method does internal
@@ -585,7 +588,7 @@ impl<Data> ConnectionCommon<Data> {
}
}
impl<T> Deref for ConnectionCommon<T> {
impl Deref for ConnectionCommon {
type Target = CommonState;
fn deref(&self) -> &Self::Target {
@@ -593,7 +596,7 @@ impl<T> Deref for ConnectionCommon<T> {
}
}
impl<T> DerefMut for ConnectionCommon<T> {
impl DerefMut for ConnectionCommon {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.common_state
}
@@ -713,12 +716,12 @@ impl CommonState {
matches!(self.negotiated_version, Some(ProtocolVersion::TLSv1_3))
}
fn process_main_protocol<Data>(
async fn process_main_protocol(
&mut self,
msg: Message,
mut state: Box<dyn State<Data>>,
data: &mut Data,
) -> Result<Box<dyn State<Data>>, Error> {
mut state: Box<dyn State<ClientConnectionData>>,
data: &mut ClientConnectionData,
) -> Result<Box<dyn State<ClientConnectionData>>, Error> {
// For TLS1.2, outside of the handshake, send rejection alerts for
// renegotiation requests. These can occur any time.
if self.may_receive_application_data && !self.is_tls13() {
@@ -733,7 +736,7 @@ impl CommonState {
}
let mut cx = Context { common: self, data };
match state.handle(&mut cx, msg) {
match state.handle(&mut cx, msg).await {
Ok(next) => {
state = next;
Ok(state)
@@ -752,11 +755,11 @@ impl CommonState {
///
/// If internal buffers are too small, this function will not accept
/// all the data.
pub(crate) fn send_some_plaintext(&mut self, data: &[u8]) -> usize {
self.send_plain(data, Limit::Yes)
pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> usize {
self.send_plain(data, Limit::Yes).await
}
pub(crate) fn send_early_plaintext(&mut self, data: &[u8]) -> usize {
pub(crate) async fn send_early_plaintext(&mut self, data: &[u8]) -> usize {
debug_assert!(self.early_traffic);
debug_assert!(self.record_layer.is_encrypting());
@@ -765,7 +768,7 @@ impl CommonState {
return 0;
}
self.send_appdata_encrypt(data, Limit::Yes)
self.send_appdata_encrypt(data, Limit::Yes).await
}
// Changing the keys must not span any fragmented handshake
@@ -788,7 +791,7 @@ impl CommonState {
Error::PeerMisbehavedError(why.to_string())
}
pub(crate) fn decrypt_incoming(
pub(crate) async fn decrypt_incoming(
&mut self,
encr: OpaqueMessage,
) -> Result<Option<PlainMessage>, Error> {
@@ -797,7 +800,7 @@ impl CommonState {
}
let encrypted_len = encr.payload.0.len();
let plain = self.record_layer.decrypt_incoming(encr);
let plain = self.record_layer.decrypt_incoming(encr).await;
match plain {
Err(Error::PeerSentOversizedRecord) => {
@@ -819,17 +822,17 @@ impl CommonState {
/// Fragment `m`, encrypt the fragments, and then queue
/// the encrypted fragments for sending.
pub(crate) fn send_msg_encrypt(&mut self, m: PlainMessage) {
pub(crate) async fn send_msg_encrypt(&mut self, m: PlainMessage) {
let mut plain_messages = VecDeque::new();
self.message_fragmenter.fragment(m, &mut plain_messages);
for m in plain_messages {
self.send_single_fragment(m.borrow());
self.send_single_fragment(m.borrow()).await;
}
}
/// Like send_msg_encrypt, but operate on an appdata directly.
fn send_appdata_encrypt(&mut self, payload: &[u8], limit: Limit) -> usize {
async fn send_appdata_encrypt(&mut self, payload: &[u8], limit: Limit) -> usize {
// Here, the limit on sendable_tls applies to encrypted data,
// but we're respecting it for plaintext data -- so we'll
// be out by whatever the cipher+record overhead is. That's a
@@ -848,13 +851,13 @@ impl CommonState {
);
for m in plain_messages {
self.send_single_fragment(m);
self.send_single_fragment(m).await;
}
len
}
fn send_single_fragment(&mut self, m: BorrowedPlainMessage) {
async fn send_single_fragment<'a>(&mut self, m: BorrowedPlainMessage<'a>) {
// Close connection once we start to run out of
// sequence space.
if self.record_layer.wants_close_before_encrypt() {
@@ -867,7 +870,7 @@ impl CommonState {
return;
}
let em = self.record_layer.encrypt_outgoing(m);
let em = self.record_layer.encrypt_outgoing(m).await;
self.queue_tls_message(em);
}
@@ -887,7 +890,7 @@ impl CommonState {
///
/// Returns the number of bytes written from `data`: this might
/// be less than `data.len()` if buffer limits were exceeded.
fn send_plain(&mut self, data: &[u8], limit: Limit) -> usize {
async fn send_plain(&mut self, data: &[u8], limit: Limit) -> usize {
if !self.may_send_application_data {
// If we haven't completed handshaking, buffer
// plaintext to send once we do.
@@ -905,7 +908,7 @@ impl CommonState {
return 0;
}
self.send_appdata_encrypt(data, limit)
self.send_appdata_encrypt(data, limit).await
}
pub(crate) fn start_outgoing_traffic(&mut self) {
@@ -1092,12 +1095,13 @@ impl CommonState {
}
}
pub(crate) trait State<Data>: Send + Sync {
fn handle(
#[async_trait]
pub(crate) trait State<ClientConnectionData>: Send + Sync {
async fn handle(
self: Box<Self>,
cx: &mut Context<'_, Data>,
cx: &mut Context<'_>,
message: Message,
) -> Result<Box<dyn State<Data>>, Error>;
) -> Result<Box<dyn State<ClientConnectionData>>, Error>;
fn export_keying_material(
&self,
@@ -1111,9 +1115,9 @@ pub(crate) trait State<Data>: Send + Sync {
fn perhaps_write_key_update(&mut self, _cx: &mut CommonState) {}
}
pub(crate) struct Context<'a, Data> {
pub(crate) struct Context<'a> {
pub(crate) common: &'a mut CommonState,
pub(crate) data: &'a mut Data,
pub(crate) data: &'a mut ClientConnectionData,
}
#[derive(Clone, Copy, Debug, PartialEq)]

View File

@@ -318,7 +318,7 @@ mod hash_hs;
mod limited_cache;
mod rand;
mod record_layer;
mod stream;
//mod stream;
#[cfg(feature = "tls12")]
mod tls12;
mod tls13;
@@ -367,7 +367,7 @@ pub use crate::msgs::enums::CipherSuite;
pub use crate::msgs::enums::ProtocolVersion;
pub use crate::msgs::enums::SignatureScheme;
pub use crate::msgs::handshake::DistinguishedNames;
pub use crate::stream::{Stream, StreamOwned};
//pub use crate::stream::{Stream, StreamOwned};
pub use crate::suites::{
BulkAlgorithm, SupportedCipherSuite, ALL_CIPHER_SUITES, DEFAULT_CIPHER_SUITES,
};

View File

@@ -156,12 +156,13 @@ impl RecordLayer {
/// `encr` is a decoded message allegedly received from the peer.
/// If it can be decrypted, its decryption is returned. Otherwise,
/// an error is returned.
pub(crate) fn decrypt_incoming(&mut self, encr: OpaqueMessage) -> Result<PlainMessage, Error> {
pub(crate) async fn decrypt_incoming(
&mut self,
encr: OpaqueMessage,
) -> Result<PlainMessage, Error> {
debug_assert!(self.is_decrypting());
let seq = self.read_seq;
let msg = self
.message_decrypter
.decrypt(encr, seq)?;
let msg = self.message_decrypter.decrypt(encr, seq)?;
self.read_seq += 1;
Ok(msg)
}
@@ -170,13 +171,14 @@ impl RecordLayer {
///
/// `plain` is a TLS message we'd like to send. This function
/// panics if the requisite keying material hasn't been established yet.
pub(crate) fn encrypt_outgoing(&mut self, plain: BorrowedPlainMessage) -> OpaqueMessage {
pub(crate) async fn encrypt_outgoing<'a>(
&mut self,
plain: BorrowedPlainMessage<'a>,
) -> OpaqueMessage {
debug_assert!(self.encrypt_state == DirectionState::Active);
assert!(!self.encrypt_exhausted());
let seq = self.write_seq;
self.write_seq += 1;
self.message_encrypter
.encrypt(plain, seq)
.unwrap()
self.message_encrypter.encrypt(plain, seq).unwrap()
}
}