mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-01-09 21:38:00 -05:00
refactoring client for async
This commit is contained in:
@@ -16,9 +16,13 @@ name = "tls_aio"
|
||||
rustversion = { version = "1.0.6", optional = true }
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.53"
|
||||
futures = "0.3.21"
|
||||
log = { version = "0.4.4", optional = true }
|
||||
ring = "0.16.20"
|
||||
sct = "0.7.0"
|
||||
tokio = "1.18.2"
|
||||
tokio-util = "0.7.2"
|
||||
webpki = { version = "0.22.0", features = ["alloc", "std"] }
|
||||
|
||||
[features]
|
||||
|
||||
@@ -18,10 +18,12 @@ use super::hs;
|
||||
|
||||
use std::convert::TryFrom;
|
||||
use std::error::Error as StdError;
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::{fmt, io, mem};
|
||||
use tokio::io::AsyncWrite;
|
||||
|
||||
/// A trait for the ability to store client session data.
|
||||
/// The keys and values are opaque.
|
||||
@@ -372,11 +374,12 @@ impl EarlyData {
|
||||
/// Stub that implements io::Write and dispatches to `write_early_data`.
|
||||
pub struct WriteEarlyData<'a> {
|
||||
sess: &'a mut ClientConnection,
|
||||
fut: Option<futures::future::BoxFuture<'a, io::Result<usize>>>,
|
||||
}
|
||||
|
||||
impl<'a> WriteEarlyData<'a> {
|
||||
fn new(sess: &'a mut ClientConnection) -> WriteEarlyData<'a> {
|
||||
WriteEarlyData { sess }
|
||||
WriteEarlyData { sess, fut: None }
|
||||
}
|
||||
|
||||
/// How many bytes you may send. Writes will become short
|
||||
@@ -386,19 +389,31 @@ impl<'a> WriteEarlyData<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> io::Write for WriteEarlyData<'a> {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
self.sess.write_early_data(buf)
|
||||
impl<'a> AsyncWrite for WriteEarlyData<'a> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
let fut = match self.fut {
|
||||
Some(fut) => fut,
|
||||
None => Box::pin(self.sess.write_early_data(buf)),
|
||||
};
|
||||
fut.as_mut().poll(cx)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
/// This represents a single TLS client connection.
|
||||
pub struct ClientConnection {
|
||||
inner: ConnectionCommon<ClientConnectionData>,
|
||||
inner: ConnectionCommon,
|
||||
}
|
||||
|
||||
impl fmt::Debug for ClientConnection {
|
||||
@@ -411,11 +426,11 @@ impl ClientConnection {
|
||||
/// Make a new ClientConnection. `config` controls how
|
||||
/// we behave in the TLS protocol, `name` is the
|
||||
/// name of the server we want to talk to.
|
||||
pub fn new(config: Arc<ClientConfig>, name: ServerName) -> Result<Self, Error> {
|
||||
Self::new_inner(config, name, Vec::new(), Protocol::Tcp)
|
||||
pub async fn new(config: Arc<ClientConfig>, name: ServerName) -> Result<Self, Error> {
|
||||
Self::new_inner(config, name, Vec::new(), Protocol::Tcp).await
|
||||
}
|
||||
|
||||
fn new_inner(
|
||||
async fn new_inner(
|
||||
config: Arc<ClientConfig>,
|
||||
name: ServerName,
|
||||
extra_exts: Vec<ClientExtension>,
|
||||
@@ -430,7 +445,7 @@ impl ClientConnection {
|
||||
data: &mut data,
|
||||
};
|
||||
|
||||
let state = hs::start_handshake(name, extra_exts, config, &mut cx)?;
|
||||
let state = hs::start_handshake(name, extra_exts, config, &mut cx).await?;
|
||||
let inner = ConnectionCommon::new(state, data, common_state);
|
||||
|
||||
Ok(Self { inner })
|
||||
@@ -471,17 +486,23 @@ impl ClientConnection {
|
||||
self.inner.data.early_data.is_accepted()
|
||||
}
|
||||
|
||||
fn write_early_data(&mut self, data: &[u8]) -> io::Result<usize> {
|
||||
self.inner
|
||||
.data
|
||||
.early_data
|
||||
.check_write(data.len())
|
||||
.map(|sz| self.inner.common_state.send_early_plaintext(&data[..sz]))
|
||||
async fn write_early_data(&mut self, data: &[u8]) -> io::Result<usize> {
|
||||
let sz = self.inner.data.early_data.check_write(data.len());
|
||||
|
||||
if let Ok(sz) = sz {
|
||||
Ok(self
|
||||
.inner
|
||||
.common_state
|
||||
.send_early_plaintext(&data[..sz])
|
||||
.await)
|
||||
} else {
|
||||
sz
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for ClientConnection {
|
||||
type Target = ConnectionCommon<ClientConnectionData>;
|
||||
type Target = ConnectionCommon;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
|
||||
@@ -33,11 +33,12 @@ use crate::client::client_conn::ClientConnectionData;
|
||||
use crate::client::common::ClientHelloDetails;
|
||||
use crate::client::{tls13, ClientConfig, ServerName};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub(super) type NextState = Box<dyn State<ClientConnectionData>>;
|
||||
pub(super) type NextStateOrError = Result<NextState, Error>;
|
||||
pub(super) type ClientContext<'a> = crate::conn::Context<'a, ClientConnectionData>;
|
||||
pub(super) type ClientContext<'a> = crate::conn::Context<'a>;
|
||||
|
||||
fn find_session(
|
||||
server_name: &ServerName,
|
||||
@@ -68,7 +69,7 @@ fn find_session(
|
||||
.and_then(|resuming| Some(resuming))
|
||||
}
|
||||
|
||||
pub(super) fn start_handshake(
|
||||
pub(super) async fn start_handshake(
|
||||
server_name: ServerName,
|
||||
extra_exts: Vec<ClientExtension>,
|
||||
config: Arc<ClientConfig>,
|
||||
@@ -132,7 +133,8 @@ pub(super) fn start_handshake(
|
||||
extra_exts,
|
||||
may_send_sct_list,
|
||||
None,
|
||||
))
|
||||
)
|
||||
.await)
|
||||
}
|
||||
|
||||
struct ExpectServerHello {
|
||||
@@ -155,7 +157,7 @@ struct ExpectServerHelloOrHelloRetryRequest {
|
||||
extra_exts: Vec<ClientExtension>,
|
||||
}
|
||||
|
||||
fn emit_client_hello_for_retry(
|
||||
async fn emit_client_hello_for_retry(
|
||||
config: Arc<ClientConfig>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
resuming_session: Option<persist::Retrieved<persist::ClientSessionValue>>,
|
||||
@@ -403,8 +405,13 @@ pub(super) fn sct_list_is_invalid(scts: &SCTList) -> bool {
|
||||
scts.is_empty() || scts.iter().any(|sct| sct.0.is_empty())
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectServerHello {
|
||||
fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> NextStateOrError {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
) -> NextStateOrError {
|
||||
let server_hello =
|
||||
require_handshake_msg!(m, HandshakeType::ServerHello, HandshakePayload::ServerHello)?;
|
||||
trace!("We got ServerHello {:#?}", server_hello);
|
||||
@@ -552,6 +559,7 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
||||
self.offered_key_share.unwrap(),
|
||||
self.sent_tls13_fake_ccs,
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(feature = "tls12")]
|
||||
SupportedCipherSuite::Tls12(suite) => {
|
||||
@@ -571,6 +579,7 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
||||
transcript,
|
||||
}
|
||||
.handle_server_hello(cx, suite, server_hello, tls13_supported)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -581,7 +590,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
Box::new(self.next)
|
||||
}
|
||||
|
||||
fn handle_hello_retry_request(
|
||||
async fn handle_hello_retry_request(
|
||||
self,
|
||||
cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
@@ -707,21 +716,23 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
self.extra_exts,
|
||||
may_send_sct_list,
|
||||
Some(cs),
|
||||
))
|
||||
)
|
||||
.await)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectServerHelloOrHelloRetryRequest {
|
||||
fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> NextStateOrError {
|
||||
async fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> NextStateOrError {
|
||||
match m.payload {
|
||||
MessagePayload::Handshake(HandshakeMessagePayload {
|
||||
payload: HandshakePayload::ServerHello(..),
|
||||
..
|
||||
}) => self.into_expect_server_hello().handle(cx, m),
|
||||
}) => self.into_expect_server_hello().handle(cx, m).await,
|
||||
MessagePayload::Handshake(HandshakeMessagePayload {
|
||||
payload: HandshakePayload::HelloRetryRequest(..),
|
||||
..
|
||||
}) => self.handle_hello_retry_request(cx, m),
|
||||
}) => self.handle_hello_retry_request(cx, m).await,
|
||||
payload => Err(inappropriate_handshake_message(
|
||||
&payload,
|
||||
&[ContentType::Handshake],
|
||||
|
||||
@@ -29,6 +29,7 @@ use crate::client::{hs, ClientConfig, ServerName};
|
||||
use ring::agreement::PublicKey;
|
||||
use ring::constant_time;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub(super) use server_hello::CompleteServerHelloHandling;
|
||||
@@ -50,16 +51,14 @@ mod server_hello {
|
||||
}
|
||||
|
||||
impl CompleteServerHelloHandling {
|
||||
pub(in crate::client) fn handle_server_hello(
|
||||
pub(in crate::client) async fn handle_server_hello(
|
||||
mut self,
|
||||
cx: &mut ClientContext,
|
||||
cx: &mut ClientContext<'_>,
|
||||
suite: &'static Tls12CipherSuite,
|
||||
server_hello: &ServerHelloPayload,
|
||||
tls13_supported: bool,
|
||||
) -> hs::NextStateOrError {
|
||||
server_hello
|
||||
.random
|
||||
.write_slice(&mut self.randoms.server);
|
||||
server_hello.random.write_slice(&mut self.randoms.server);
|
||||
|
||||
// Look for TLS1.3 downgrade signal in server random
|
||||
// both the server random and TLS12_DOWNGRADE_SENTINEL are
|
||||
@@ -132,8 +131,7 @@ mod server_hello {
|
||||
&secrets.randoms.client,
|
||||
&secrets.master_secret,
|
||||
);
|
||||
cx.common
|
||||
.start_encryption_tls12(&secrets, Side::Client);
|
||||
cx.common.start_encryption_tls12(&secrets, Side::Client);
|
||||
|
||||
// Since we're resuming, we verified the certificate and
|
||||
// proof of possession in the prior session.
|
||||
@@ -203,8 +201,9 @@ struct ExpectCertificate {
|
||||
server_cert_sct_list: Option<SCTList>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectCertificate {
|
||||
fn handle(
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
_cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
@@ -264,46 +263,57 @@ struct ExpectCertificateStatusOrServerKx {
|
||||
must_issue_new_ticket: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectCertificateStatusOrServerKx {
|
||||
fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
|
||||
async fn handle(
|
||||
self: Box<Self>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
) -> hs::NextStateOrError {
|
||||
match m.payload {
|
||||
MessagePayload::Handshake(HandshakeMessagePayload {
|
||||
payload: HandshakePayload::ServerKeyExchange(..),
|
||||
..
|
||||
}) => Box::new(ExpectServerKx {
|
||||
config: self.config,
|
||||
resuming_session: self.resuming_session,
|
||||
session_id: self.session_id,
|
||||
server_name: self.server_name,
|
||||
randoms: self.randoms,
|
||||
using_ems: self.using_ems,
|
||||
transcript: self.transcript,
|
||||
suite: self.suite,
|
||||
server_cert: ServerCertDetails::new(
|
||||
self.server_cert_chain,
|
||||
vec![],
|
||||
self.server_cert_sct_list,
|
||||
),
|
||||
must_issue_new_ticket: self.must_issue_new_ticket,
|
||||
})
|
||||
.handle(cx, m),
|
||||
}) => {
|
||||
Box::new(ExpectServerKx {
|
||||
config: self.config,
|
||||
resuming_session: self.resuming_session,
|
||||
session_id: self.session_id,
|
||||
server_name: self.server_name,
|
||||
randoms: self.randoms,
|
||||
using_ems: self.using_ems,
|
||||
transcript: self.transcript,
|
||||
suite: self.suite,
|
||||
server_cert: ServerCertDetails::new(
|
||||
self.server_cert_chain,
|
||||
vec![],
|
||||
self.server_cert_sct_list,
|
||||
),
|
||||
must_issue_new_ticket: self.must_issue_new_ticket,
|
||||
})
|
||||
.handle(cx, m)
|
||||
.await
|
||||
}
|
||||
MessagePayload::Handshake(HandshakeMessagePayload {
|
||||
payload: HandshakePayload::CertificateStatus(..),
|
||||
..
|
||||
}) => Box::new(ExpectCertificateStatus {
|
||||
config: self.config,
|
||||
resuming_session: self.resuming_session,
|
||||
session_id: self.session_id,
|
||||
server_name: self.server_name,
|
||||
randoms: self.randoms,
|
||||
using_ems: self.using_ems,
|
||||
transcript: self.transcript,
|
||||
suite: self.suite,
|
||||
server_cert_sct_list: self.server_cert_sct_list,
|
||||
server_cert_chain: self.server_cert_chain,
|
||||
must_issue_new_ticket: self.must_issue_new_ticket,
|
||||
})
|
||||
.handle(cx, m),
|
||||
}) => {
|
||||
Box::new(ExpectCertificateStatus {
|
||||
config: self.config,
|
||||
resuming_session: self.resuming_session,
|
||||
session_id: self.session_id,
|
||||
server_name: self.server_name,
|
||||
randoms: self.randoms,
|
||||
using_ems: self.using_ems,
|
||||
transcript: self.transcript,
|
||||
suite: self.suite,
|
||||
server_cert_sct_list: self.server_cert_sct_list,
|
||||
server_cert_chain: self.server_cert_chain,
|
||||
must_issue_new_ticket: self.must_issue_new_ticket,
|
||||
})
|
||||
.handle(cx, m)
|
||||
.await
|
||||
}
|
||||
payload => Err(inappropriate_handshake_message(
|
||||
&payload,
|
||||
&[ContentType::Handshake],
|
||||
@@ -330,8 +340,9 @@ struct ExpectCertificateStatus {
|
||||
must_issue_new_ticket: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectCertificateStatus {
|
||||
fn handle(
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
_cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
@@ -383,8 +394,13 @@ struct ExpectServerKx {
|
||||
must_issue_new_ticket: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectServerKx {
|
||||
fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
) -> hs::NextStateOrError {
|
||||
let opaque_kx = require_handshake_msg!(
|
||||
m,
|
||||
HandshakeType::ServerKeyExchange,
|
||||
@@ -392,13 +408,10 @@ impl State<ClientConnectionData> for ExpectServerKx {
|
||||
)?;
|
||||
self.transcript.add_message(&m);
|
||||
|
||||
let ecdhe = opaque_kx
|
||||
.unwrap_given_kxa(&self.suite.kx)
|
||||
.ok_or_else(|| {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::DecodeError);
|
||||
Error::CorruptMessagePayload(ContentType::Handshake)
|
||||
})?;
|
||||
let ecdhe = opaque_kx.unwrap_given_kxa(&self.suite.kx).ok_or_else(|| {
|
||||
cx.common.send_fatal_alert(AlertDescription::DecodeError);
|
||||
Error::CorruptMessagePayload(ContentType::Handshake)
|
||||
})?;
|
||||
|
||||
// Save the signature and signed parameters for later verification.
|
||||
let mut kx_params = Vec::new();
|
||||
@@ -548,8 +561,13 @@ struct ExpectServerDoneOrCertReq {
|
||||
must_issue_new_ticket: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectServerDoneOrCertReq {
|
||||
fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
) -> hs::NextStateOrError {
|
||||
if matches!(
|
||||
m.payload,
|
||||
MessagePayload::Handshake(HandshakeMessagePayload {
|
||||
@@ -571,6 +589,7 @@ impl State<ClientConnectionData> for ExpectServerDoneOrCertReq {
|
||||
must_issue_new_ticket: self.must_issue_new_ticket,
|
||||
})
|
||||
.handle(cx, m)
|
||||
.await
|
||||
} else {
|
||||
self.transcript.abandon_client_auth();
|
||||
|
||||
@@ -589,6 +608,7 @@ impl State<ClientConnectionData> for ExpectServerDoneOrCertReq {
|
||||
must_issue_new_ticket: self.must_issue_new_ticket,
|
||||
})
|
||||
.handle(cx, m)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -607,8 +627,9 @@ struct ExpectCertificateRequest {
|
||||
must_issue_new_ticket: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectCertificateRequest {
|
||||
fn handle(
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
_cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
@@ -629,9 +650,7 @@ impl State<ClientConnectionData> for ExpectCertificateRequest {
|
||||
|
||||
const NO_CONTEXT: Option<Vec<u8>> = None; // TLS 1.2 doesn't use a context.
|
||||
let client_auth = ClientAuthDetails::resolve(
|
||||
self.config
|
||||
.client_auth_cert_resolver
|
||||
.as_ref(),
|
||||
self.config.client_auth_cert_resolver.as_ref(),
|
||||
Some(&certreq.canames),
|
||||
&certreq.sigschemes,
|
||||
NO_CONTEXT,
|
||||
@@ -669,8 +688,13 @@ struct ExpectServerDone {
|
||||
must_issue_new_ticket: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectServerDone {
|
||||
fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
|
||||
async fn handle(
|
||||
self: Box<Self>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
) -> hs::NextStateOrError {
|
||||
match m.payload {
|
||||
MessagePayload::Handshake(HandshakeMessagePayload {
|
||||
payload: HandshakePayload::ServerHelloDone,
|
||||
@@ -778,9 +802,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
|
||||
let mut transcript = st.transcript;
|
||||
emit_clientkx(&mut transcript, cx.common, &kx.pubkey);
|
||||
// nb. EMS handshake hash only runs up to ClientKeyExchange.
|
||||
let ems_seed = st
|
||||
.using_ems
|
||||
.then(|| transcript.get_current_hash());
|
||||
let ems_seed = st.using_ems.then(|| transcript.get_current_hash());
|
||||
|
||||
// 5c.
|
||||
if let Some(ClientAuthDetails::Verify { signer, .. }) = &st.client_auth {
|
||||
@@ -804,11 +826,8 @@ impl State<ClientConnectionData> for ExpectServerDone {
|
||||
&secrets.randoms.client,
|
||||
&secrets.master_secret,
|
||||
);
|
||||
cx.common
|
||||
.start_encryption_tls12(&secrets, Side::Client);
|
||||
cx.common
|
||||
.record_layer
|
||||
.start_encrypting();
|
||||
cx.common.start_encryption_tls12(&secrets, Side::Client);
|
||||
cx.common.record_layer.start_encrypting();
|
||||
|
||||
// 6.
|
||||
emit_finished(&secrets, &mut transcript, cx.common);
|
||||
@@ -857,8 +876,9 @@ struct ExpectNewTicket {
|
||||
sig_verified: verify::HandshakeSignatureValid,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectNewTicket {
|
||||
fn handle(
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
_cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
@@ -902,8 +922,13 @@ struct ExpectCcs {
|
||||
sig_verified: verify::HandshakeSignatureValid,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectCcs {
|
||||
fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
|
||||
async fn handle(
|
||||
self: Box<Self>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
) -> hs::NextStateOrError {
|
||||
match m.payload {
|
||||
MessagePayload::ChangeCipherSpec(..) => {}
|
||||
payload => {
|
||||
@@ -918,9 +943,7 @@ impl State<ClientConnectionData> for ExpectCcs {
|
||||
cx.common.check_aligned_handshake()?;
|
||||
|
||||
// nb. msgs layer validates trivial contents of CCS
|
||||
cx.common
|
||||
.record_layer
|
||||
.start_decrypting();
|
||||
cx.common.record_layer.start_decrypting();
|
||||
|
||||
Ok(Box::new(ExpectFinished {
|
||||
config: self.config,
|
||||
@@ -987,10 +1010,7 @@ impl ExpectFinished {
|
||||
self.session_id,
|
||||
ticket,
|
||||
self.secrets.get_master_secret(),
|
||||
cx.common
|
||||
.peer_certificates
|
||||
.clone()
|
||||
.unwrap_or_default(),
|
||||
cx.common.peer_certificates.clone().unwrap_or_default(),
|
||||
time_now,
|
||||
lifetime,
|
||||
self.using_ems,
|
||||
@@ -1009,8 +1029,13 @@ impl ExpectFinished {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectFinished {
|
||||
fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
|
||||
async fn handle(
|
||||
self: Box<Self>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
) -> hs::NextStateOrError {
|
||||
let mut st = *self;
|
||||
let finished =
|
||||
require_handshake_msg!(m, HandshakeType::Finished, HandshakePayload::Finished)?;
|
||||
@@ -1026,8 +1051,7 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
let _fin_verified =
|
||||
constant_time::verify_slices_are_equal(&expect_verify_data, &finished.0)
|
||||
.map_err(|_| {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::DecryptError);
|
||||
cx.common.send_fatal_alert(AlertDescription::DecryptError);
|
||||
Error::DecryptError
|
||||
})
|
||||
.map(|_| verify::FinishedMessageVerified::assertion())?;
|
||||
@@ -1039,9 +1063,7 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
|
||||
if st.resuming {
|
||||
emit_ccs(cx.common);
|
||||
cx.common
|
||||
.record_layer
|
||||
.start_encrypting();
|
||||
cx.common.record_layer.start_encrypting();
|
||||
emit_finished(&st.secrets, &mut st.transcript, cx.common);
|
||||
}
|
||||
|
||||
@@ -1063,12 +1085,15 @@ struct ExpectTraffic {
|
||||
_fin_verified: verify::FinishedMessageVerified,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectTraffic {
|
||||
fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
|
||||
async fn handle(
|
||||
self: Box<Self>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
) -> hs::NextStateOrError {
|
||||
match m.payload {
|
||||
MessagePayload::ApplicationData(payload) => cx
|
||||
.common
|
||||
.take_received_plaintext(payload),
|
||||
MessagePayload::ApplicationData(payload) => cx.common.take_received_plaintext(payload),
|
||||
payload => {
|
||||
return Err(inappropriate_message(
|
||||
&payload,
|
||||
@@ -1085,8 +1110,7 @@ impl State<ClientConnectionData> for ExpectTraffic {
|
||||
label: &[u8],
|
||||
context: Option<&[u8]>,
|
||||
) -> Result<(), Error> {
|
||||
self.secrets
|
||||
.export_keying_material(output, label, context);
|
||||
self.secrets.export_keying_material(output, label, context);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ use crate::ticketer::TimeBase;
|
||||
use ring::constant_time;
|
||||
|
||||
use crate::sign::{CertifiedKey, Signer};
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
|
||||
// Extensions we expect in plaintext in the ServerHello.
|
||||
@@ -56,9 +57,9 @@ static DISALLOWED_TLS13_EXTS: &[ExtensionType] = &[
|
||||
ExtensionType::ExtendedMasterSecret,
|
||||
];
|
||||
|
||||
pub(super) fn handle_server_hello(
|
||||
pub(super) async fn handle_server_hello(
|
||||
config: Arc<ClientConfig>,
|
||||
cx: &mut ClientContext,
|
||||
cx: &mut ClientContext<'_>,
|
||||
server_hello: &ServerHelloPayload,
|
||||
mut resuming_session: Option<persist::Tls13ClientSessionValue>,
|
||||
server_name: ServerName,
|
||||
@@ -351,8 +352,13 @@ struct ExpectEncryptedExtensions {
|
||||
hello: ClientHelloDetails,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectEncryptedExtensions {
|
||||
fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
) -> hs::NextStateOrError {
|
||||
let exts = require_handshake_msg!(
|
||||
m,
|
||||
HandshakeType::EncryptedExtensions,
|
||||
@@ -427,36 +433,47 @@ struct ExpectCertificateOrCertReq {
|
||||
may_send_sct_list: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectCertificateOrCertReq {
|
||||
fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
|
||||
async fn handle(
|
||||
self: Box<Self>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
) -> hs::NextStateOrError {
|
||||
match m.payload {
|
||||
MessagePayload::Handshake(HandshakeMessagePayload {
|
||||
payload: HandshakePayload::CertificateTLS13(..),
|
||||
..
|
||||
}) => Box::new(ExpectCertificate {
|
||||
config: self.config,
|
||||
server_name: self.server_name,
|
||||
randoms: self.randoms,
|
||||
suite: self.suite,
|
||||
transcript: self.transcript,
|
||||
key_schedule: self.key_schedule,
|
||||
may_send_sct_list: self.may_send_sct_list,
|
||||
client_auth: None,
|
||||
})
|
||||
.handle(cx, m),
|
||||
}) => {
|
||||
Box::new(ExpectCertificate {
|
||||
config: self.config,
|
||||
server_name: self.server_name,
|
||||
randoms: self.randoms,
|
||||
suite: self.suite,
|
||||
transcript: self.transcript,
|
||||
key_schedule: self.key_schedule,
|
||||
may_send_sct_list: self.may_send_sct_list,
|
||||
client_auth: None,
|
||||
})
|
||||
.handle(cx, m)
|
||||
.await
|
||||
}
|
||||
MessagePayload::Handshake(HandshakeMessagePayload {
|
||||
payload: HandshakePayload::CertificateRequestTLS13(..),
|
||||
..
|
||||
}) => Box::new(ExpectCertificateRequest {
|
||||
config: self.config,
|
||||
server_name: self.server_name,
|
||||
randoms: self.randoms,
|
||||
suite: self.suite,
|
||||
transcript: self.transcript,
|
||||
key_schedule: self.key_schedule,
|
||||
may_send_sct_list: self.may_send_sct_list,
|
||||
})
|
||||
.handle(cx, m),
|
||||
}) => {
|
||||
Box::new(ExpectCertificateRequest {
|
||||
config: self.config,
|
||||
server_name: self.server_name,
|
||||
randoms: self.randoms,
|
||||
suite: self.suite,
|
||||
transcript: self.transcript,
|
||||
key_schedule: self.key_schedule,
|
||||
may_send_sct_list: self.may_send_sct_list,
|
||||
})
|
||||
.handle(cx, m)
|
||||
.await
|
||||
}
|
||||
payload => Err(inappropriate_handshake_message(
|
||||
&payload,
|
||||
&[ContentType::Handshake],
|
||||
@@ -482,8 +499,13 @@ struct ExpectCertificateRequest {
|
||||
may_send_sct_list: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectCertificateRequest {
|
||||
fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
) -> hs::NextStateOrError {
|
||||
let certreq = &require_handshake_msg!(
|
||||
m,
|
||||
HandshakeType::CertificateRequest,
|
||||
@@ -551,8 +573,13 @@ struct ExpectCertificate {
|
||||
client_auth: Option<ClientAuthDetails>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectCertificate {
|
||||
fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
) -> hs::NextStateOrError {
|
||||
let cert_chain = require_handshake_msg!(
|
||||
m,
|
||||
HandshakeType::Certificate,
|
||||
@@ -621,8 +648,13 @@ struct ExpectCertificateVerify {
|
||||
client_auth: Option<ClientAuthDetails>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectCertificateVerify {
|
||||
fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
) -> hs::NextStateOrError {
|
||||
let cert_verify = require_handshake_msg!(
|
||||
m,
|
||||
HandshakeType::CertificateVerify,
|
||||
@@ -780,8 +812,13 @@ struct ExpectFinished {
|
||||
sig_verified: verify::HandshakeSignatureValid,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectFinished {
|
||||
fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
|
||||
async fn handle(
|
||||
self: Box<Self>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
) -> hs::NextStateOrError {
|
||||
let mut st = *self;
|
||||
let finished =
|
||||
require_handshake_msg!(m, HandshakeType::Finished, HandshakePayload::Finished)?;
|
||||
@@ -892,7 +929,7 @@ struct ExpectTraffic {
|
||||
|
||||
impl ExpectTraffic {
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn handle_new_ticket_tls13(
|
||||
async fn handle_new_ticket_tls13(
|
||||
&mut self,
|
||||
cx: &mut ClientContext<'_>,
|
||||
nst: &NewSessionTicketPayloadTLS13,
|
||||
@@ -944,7 +981,7 @@ impl ExpectTraffic {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_key_update(
|
||||
async fn handle_key_update(
|
||||
&mut self,
|
||||
common: &mut CommonState,
|
||||
kur: &KeyUpdateRequest,
|
||||
@@ -973,18 +1010,23 @@ impl ExpectTraffic {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl State<ClientConnectionData> for ExpectTraffic {
|
||||
fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
m: Message,
|
||||
) -> hs::NextStateOrError {
|
||||
match m.payload {
|
||||
MessagePayload::ApplicationData(payload) => cx.common.take_received_plaintext(payload),
|
||||
MessagePayload::Handshake(HandshakeMessagePayload {
|
||||
payload: HandshakePayload::NewSessionTicketTLS13(ref new_ticket),
|
||||
..
|
||||
}) => self.handle_new_ticket_tls13(cx, new_ticket)?,
|
||||
}) => self.handle_new_ticket_tls13(cx, new_ticket).await?,
|
||||
MessagePayload::Handshake(HandshakeMessagePayload {
|
||||
payload: HandshakePayload::KeyUpdate(ref key_update),
|
||||
..
|
||||
}) => self.handle_key_update(cx.common, key_update)?,
|
||||
}) => self.handle_key_update(cx.common, key_update).await?,
|
||||
payload => {
|
||||
return Err(inappropriate_handshake_message(
|
||||
&payload,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::client::ClientConnectionData;
|
||||
use crate::error::Error;
|
||||
use crate::key;
|
||||
#[cfg(feature = "logging")]
|
||||
@@ -19,11 +20,16 @@ use crate::suites::SupportedCipherSuite;
|
||||
use crate::tls12::ConnectionSecrets;
|
||||
use crate::vecbuf::ChunkVecBuffer;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use std::collections::VecDeque;
|
||||
use std::convert::TryFrom;
|
||||
use std::io;
|
||||
use std::mem;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::pin::Pin;
|
||||
use std::task;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio_util::sync::ReusableBoxFuture;
|
||||
|
||||
/// Values of this structure are returned from [`Connection::process_new_packets`]
|
||||
/// and tell the caller the current I/O state of the TLS connection.
|
||||
@@ -150,36 +156,10 @@ impl<'a> io::Read for Reader<'a> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal trait implemented by the [`ServerConnection`]/[`ClientConnection`]
|
||||
/// allowing them to be the subject of a [`Writer`].
|
||||
pub trait PlaintextSink {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize>;
|
||||
fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize>;
|
||||
fn flush(&mut self) -> io::Result<()>;
|
||||
}
|
||||
|
||||
impl<T> PlaintextSink for ConnectionCommon<T> {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
Ok(self.send_some_plaintext(buf))
|
||||
}
|
||||
|
||||
fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
|
||||
let mut sz = 0;
|
||||
for buf in bufs {
|
||||
sz += self.send_some_plaintext(buf);
|
||||
}
|
||||
Ok(sz)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A structure that implements [`std::io::Write`] for writing plaintext.
|
||||
/// A structure that implements [`tokio::io::AsyncWrite`] for writing plaintext.
|
||||
pub struct Writer<'a> {
|
||||
sink: &'a mut dyn PlaintextSink,
|
||||
con: &'a mut ConnectionCommon,
|
||||
fut: Option<ReusableBoxFuture<'a, usize>>,
|
||||
}
|
||||
|
||||
impl<'a> Writer<'a> {
|
||||
@@ -188,12 +168,12 @@ impl<'a> Writer<'a> {
|
||||
/// This is not an external interface. Get one of these objects
|
||||
/// from [`Connection::writer`].
|
||||
#[doc(hidden)]
|
||||
pub fn new(sink: &'a mut dyn PlaintextSink) -> Writer<'a> {
|
||||
Writer { sink }
|
||||
pub fn new(con: &'a mut ConnectionCommon) -> Writer<'a> {
|
||||
Writer { con, fut: None }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> io::Write for Writer<'a> {
|
||||
impl<'a> AsyncWrite for Writer<'a> {
|
||||
/// Send the plaintext `buf` to the peer, encrypting
|
||||
/// and authenticating it. Once this function succeeds
|
||||
/// you should call [`CommonState::write_tls`] which will output the
|
||||
@@ -203,16 +183,33 @@ impl<'a> io::Write for Writer<'a> {
|
||||
/// TLS handshake completes, and sends it as soon
|
||||
/// as it can. See [`CommonState::set_buffer_limit`] to control
|
||||
/// the size of this buffer.
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
self.sink.write(buf)
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> task::Poll<io::Result<usize>> {
|
||||
let fut = match self.fut.as_mut() {
|
||||
Some(fut) => fut,
|
||||
None => {
|
||||
let fut = self.con.send_some_plaintext(buf);
|
||||
self.fut.get_or_insert(ReusableBoxFuture::new(fut))
|
||||
}
|
||||
};
|
||||
match fut.poll(cx) {
|
||||
task::Poll::Ready(sz) => task::Poll::Ready(Ok(sz)),
|
||||
task::Poll::Pending => task::Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
|
||||
self.sink.write_vectored(bufs)
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
|
||||
task::Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
self.sink.flush()
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
) -> task::Poll<Result<(), io::Error>> {
|
||||
task::Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,16 +252,20 @@ enum Limit {
|
||||
}
|
||||
|
||||
/// Interface shared by client and server connections.
|
||||
pub struct ConnectionCommon<Data> {
|
||||
state: Result<Box<dyn State<Data>>, Error>,
|
||||
pub(crate) data: Data,
|
||||
pub struct ConnectionCommon {
|
||||
state: Result<Box<dyn State<ClientConnectionData>>, Error>,
|
||||
pub(crate) data: ClientConnectionData,
|
||||
pub(crate) common_state: CommonState,
|
||||
message_deframer: MessageDeframer,
|
||||
handshake_joiner: HandshakeJoiner,
|
||||
}
|
||||
|
||||
impl<Data> ConnectionCommon<Data> {
|
||||
pub(crate) fn new(state: Box<dyn State<Data>>, data: Data, common_state: CommonState) -> Self {
|
||||
impl ConnectionCommon {
|
||||
pub(crate) fn new(
|
||||
state: Box<dyn State<ClientConnectionData>>,
|
||||
data: ClientConnectionData,
|
||||
common_state: CommonState,
|
||||
) -> Self {
|
||||
Self {
|
||||
state: Ok(state),
|
||||
data,
|
||||
@@ -320,7 +321,7 @@ impl<Data> ConnectionCommon<Data> {
|
||||
/// [`write_tls`]: CommonState::write_tls
|
||||
/// [`read_tls`]: ConnectionCommon::read_tls
|
||||
/// [`process_new_packets`]: ConnectionCommon::process_new_packets
|
||||
pub fn complete_io<T>(&mut self, io: &mut T) -> Result<(usize, usize), io::Error>
|
||||
pub async fn complete_io<T>(&mut self, io: &mut T) -> Result<(usize, usize), io::Error>
|
||||
where
|
||||
Self: Sized,
|
||||
T: io::Read + io::Write,
|
||||
@@ -346,7 +347,7 @@ impl<Data> ConnectionCommon<Data> {
|
||||
}
|
||||
}
|
||||
|
||||
match self.process_new_packets() {
|
||||
match self.process_new_packets().await {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
// In case we have an alert to send describing this error,
|
||||
@@ -396,15 +397,15 @@ impl<Data> ConnectionCommon<Data> {
|
||||
Ok(self.handshake_joiner.frames.pop_front())
|
||||
}
|
||||
|
||||
pub(crate) fn replace_state(&mut self, new: Box<dyn State<Data>>) {
|
||||
pub(crate) fn replace_state(&mut self, new: Box<dyn State<ClientConnectionData>>) {
|
||||
self.state = Ok(new);
|
||||
}
|
||||
|
||||
fn process_msg(
|
||||
async fn process_msg(
|
||||
&mut self,
|
||||
msg: OpaqueMessage,
|
||||
state: Box<dyn State<Data>>,
|
||||
) -> Result<Box<dyn State<Data>>, Error> {
|
||||
state: Box<dyn State<ClientConnectionData>>,
|
||||
) -> Result<Box<dyn State<ClientConnectionData>>, Error> {
|
||||
// Drop CCS messages during handshake in TLS1.3
|
||||
if msg.typ == ContentType::ChangeCipherSpec
|
||||
&& !self.common_state.may_receive_application_data
|
||||
@@ -430,7 +431,7 @@ impl<Data> ConnectionCommon<Data> {
|
||||
|
||||
// Decrypt if demanded by current state.
|
||||
let msg = match self.common_state.record_layer.is_decrypting() {
|
||||
true => match self.common_state.decrypt_incoming(msg) {
|
||||
true => match self.common_state.decrypt_incoming(msg).await {
|
||||
Ok(None) => {
|
||||
// message dropped
|
||||
return Ok(state);
|
||||
@@ -454,7 +455,7 @@ impl<Data> ConnectionCommon<Data> {
|
||||
.send_fatal_alert(AlertDescription::DecodeError);
|
||||
Error::CorruptMessagePayload(ContentType::Handshake)
|
||||
})?;
|
||||
return self.process_new_handshake_messages(state);
|
||||
return self.process_new_handshake_messages(state).await;
|
||||
}
|
||||
|
||||
// Now we can fully parse the message payload.
|
||||
@@ -468,6 +469,7 @@ impl<Data> ConnectionCommon<Data> {
|
||||
|
||||
self.common_state
|
||||
.process_main_protocol(msg, state, &mut self.data)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Processes any new packets read by a previous call to
|
||||
@@ -488,7 +490,7 @@ impl<Data> ConnectionCommon<Data> {
|
||||
///
|
||||
/// [`read_tls`]: Connection::read_tls
|
||||
/// [`process_new_packets`]: Connection::process_new_packets
|
||||
pub fn process_new_packets(&mut self) -> Result<IoState, Error> {
|
||||
pub async fn process_new_packets(&mut self) -> Result<IoState, Error> {
|
||||
let mut state = match mem::replace(&mut self.state, Err(Error::HandshakeNotComplete)) {
|
||||
Ok(state) => state,
|
||||
Err(e) => {
|
||||
@@ -502,7 +504,7 @@ impl<Data> ConnectionCommon<Data> {
|
||||
}
|
||||
|
||||
while let Some(msg) = self.message_deframer.frames.pop_front() {
|
||||
match self.process_msg(msg, state) {
|
||||
match self.process_msg(msg, state).await {
|
||||
Ok(new) => state = new,
|
||||
Err(e) => {
|
||||
self.state = Err(e.clone());
|
||||
@@ -515,25 +517,26 @@ impl<Data> ConnectionCommon<Data> {
|
||||
Ok(self.common_state.current_io_state())
|
||||
}
|
||||
|
||||
fn process_new_handshake_messages(
|
||||
async fn process_new_handshake_messages(
|
||||
&mut self,
|
||||
mut state: Box<dyn State<Data>>,
|
||||
) -> Result<Box<dyn State<Data>>, Error> {
|
||||
mut state: Box<dyn State<ClientConnectionData>>,
|
||||
) -> Result<Box<dyn State<ClientConnectionData>>, Error> {
|
||||
self.common_state.aligned_handshake = self.handshake_joiner.is_empty();
|
||||
while let Some(msg) = self.handshake_joiner.frames.pop_front() {
|
||||
state = self
|
||||
.common_state
|
||||
.process_main_protocol(msg, state, &mut self.data)?;
|
||||
.process_main_protocol(msg, state, &mut self.data)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
pub(crate) fn send_some_plaintext(&mut self, buf: &[u8]) -> usize {
|
||||
pub(crate) async fn send_some_plaintext(&mut self, buf: &[u8]) -> usize {
|
||||
if let Ok(st) = &mut self.state {
|
||||
st.perhaps_write_key_update(&mut self.common_state);
|
||||
}
|
||||
self.common_state.send_some_plaintext(buf)
|
||||
self.common_state.send_some_plaintext(buf).await
|
||||
}
|
||||
|
||||
/// Read TLS content from `rd`. This method does internal
|
||||
@@ -585,7 +588,7 @@ impl<Data> ConnectionCommon<Data> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Deref for ConnectionCommon<T> {
|
||||
impl Deref for ConnectionCommon {
|
||||
type Target = CommonState;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
@@ -593,7 +596,7 @@ impl<T> Deref for ConnectionCommon<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> DerefMut for ConnectionCommon<T> {
|
||||
impl DerefMut for ConnectionCommon {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.common_state
|
||||
}
|
||||
@@ -713,12 +716,12 @@ impl CommonState {
|
||||
matches!(self.negotiated_version, Some(ProtocolVersion::TLSv1_3))
|
||||
}
|
||||
|
||||
fn process_main_protocol<Data>(
|
||||
async fn process_main_protocol(
|
||||
&mut self,
|
||||
msg: Message,
|
||||
mut state: Box<dyn State<Data>>,
|
||||
data: &mut Data,
|
||||
) -> Result<Box<dyn State<Data>>, Error> {
|
||||
mut state: Box<dyn State<ClientConnectionData>>,
|
||||
data: &mut ClientConnectionData,
|
||||
) -> Result<Box<dyn State<ClientConnectionData>>, Error> {
|
||||
// For TLS1.2, outside of the handshake, send rejection alerts for
|
||||
// renegotiation requests. These can occur any time.
|
||||
if self.may_receive_application_data && !self.is_tls13() {
|
||||
@@ -733,7 +736,7 @@ impl CommonState {
|
||||
}
|
||||
|
||||
let mut cx = Context { common: self, data };
|
||||
match state.handle(&mut cx, msg) {
|
||||
match state.handle(&mut cx, msg).await {
|
||||
Ok(next) => {
|
||||
state = next;
|
||||
Ok(state)
|
||||
@@ -752,11 +755,11 @@ impl CommonState {
|
||||
///
|
||||
/// If internal buffers are too small, this function will not accept
|
||||
/// all the data.
|
||||
pub(crate) fn send_some_plaintext(&mut self, data: &[u8]) -> usize {
|
||||
self.send_plain(data, Limit::Yes)
|
||||
pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> usize {
|
||||
self.send_plain(data, Limit::Yes).await
|
||||
}
|
||||
|
||||
pub(crate) fn send_early_plaintext(&mut self, data: &[u8]) -> usize {
|
||||
pub(crate) async fn send_early_plaintext(&mut self, data: &[u8]) -> usize {
|
||||
debug_assert!(self.early_traffic);
|
||||
debug_assert!(self.record_layer.is_encrypting());
|
||||
|
||||
@@ -765,7 +768,7 @@ impl CommonState {
|
||||
return 0;
|
||||
}
|
||||
|
||||
self.send_appdata_encrypt(data, Limit::Yes)
|
||||
self.send_appdata_encrypt(data, Limit::Yes).await
|
||||
}
|
||||
|
||||
// Changing the keys must not span any fragmented handshake
|
||||
@@ -788,7 +791,7 @@ impl CommonState {
|
||||
Error::PeerMisbehavedError(why.to_string())
|
||||
}
|
||||
|
||||
pub(crate) fn decrypt_incoming(
|
||||
pub(crate) async fn decrypt_incoming(
|
||||
&mut self,
|
||||
encr: OpaqueMessage,
|
||||
) -> Result<Option<PlainMessage>, Error> {
|
||||
@@ -797,7 +800,7 @@ impl CommonState {
|
||||
}
|
||||
|
||||
let encrypted_len = encr.payload.0.len();
|
||||
let plain = self.record_layer.decrypt_incoming(encr);
|
||||
let plain = self.record_layer.decrypt_incoming(encr).await;
|
||||
|
||||
match plain {
|
||||
Err(Error::PeerSentOversizedRecord) => {
|
||||
@@ -819,17 +822,17 @@ impl CommonState {
|
||||
|
||||
/// Fragment `m`, encrypt the fragments, and then queue
|
||||
/// the encrypted fragments for sending.
|
||||
pub(crate) fn send_msg_encrypt(&mut self, m: PlainMessage) {
|
||||
pub(crate) async fn send_msg_encrypt(&mut self, m: PlainMessage) {
|
||||
let mut plain_messages = VecDeque::new();
|
||||
self.message_fragmenter.fragment(m, &mut plain_messages);
|
||||
|
||||
for m in plain_messages {
|
||||
self.send_single_fragment(m.borrow());
|
||||
self.send_single_fragment(m.borrow()).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Like send_msg_encrypt, but operate on an appdata directly.
|
||||
fn send_appdata_encrypt(&mut self, payload: &[u8], limit: Limit) -> usize {
|
||||
async fn send_appdata_encrypt(&mut self, payload: &[u8], limit: Limit) -> usize {
|
||||
// Here, the limit on sendable_tls applies to encrypted data,
|
||||
// but we're respecting it for plaintext data -- so we'll
|
||||
// be out by whatever the cipher+record overhead is. That's a
|
||||
@@ -848,13 +851,13 @@ impl CommonState {
|
||||
);
|
||||
|
||||
for m in plain_messages {
|
||||
self.send_single_fragment(m);
|
||||
self.send_single_fragment(m).await;
|
||||
}
|
||||
|
||||
len
|
||||
}
|
||||
|
||||
fn send_single_fragment(&mut self, m: BorrowedPlainMessage) {
|
||||
async fn send_single_fragment<'a>(&mut self, m: BorrowedPlainMessage<'a>) {
|
||||
// Close connection once we start to run out of
|
||||
// sequence space.
|
||||
if self.record_layer.wants_close_before_encrypt() {
|
||||
@@ -867,7 +870,7 @@ impl CommonState {
|
||||
return;
|
||||
}
|
||||
|
||||
let em = self.record_layer.encrypt_outgoing(m);
|
||||
let em = self.record_layer.encrypt_outgoing(m).await;
|
||||
self.queue_tls_message(em);
|
||||
}
|
||||
|
||||
@@ -887,7 +890,7 @@ impl CommonState {
|
||||
///
|
||||
/// Returns the number of bytes written from `data`: this might
|
||||
/// be less than `data.len()` if buffer limits were exceeded.
|
||||
fn send_plain(&mut self, data: &[u8], limit: Limit) -> usize {
|
||||
async fn send_plain(&mut self, data: &[u8], limit: Limit) -> usize {
|
||||
if !self.may_send_application_data {
|
||||
// If we haven't completed handshaking, buffer
|
||||
// plaintext to send once we do.
|
||||
@@ -905,7 +908,7 @@ impl CommonState {
|
||||
return 0;
|
||||
}
|
||||
|
||||
self.send_appdata_encrypt(data, limit)
|
||||
self.send_appdata_encrypt(data, limit).await
|
||||
}
|
||||
|
||||
pub(crate) fn start_outgoing_traffic(&mut self) {
|
||||
@@ -1092,12 +1095,13 @@ impl CommonState {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait State<Data>: Send + Sync {
|
||||
fn handle(
|
||||
#[async_trait]
|
||||
pub(crate) trait State<ClientConnectionData>: Send + Sync {
|
||||
async fn handle(
|
||||
self: Box<Self>,
|
||||
cx: &mut Context<'_, Data>,
|
||||
cx: &mut Context<'_>,
|
||||
message: Message,
|
||||
) -> Result<Box<dyn State<Data>>, Error>;
|
||||
) -> Result<Box<dyn State<ClientConnectionData>>, Error>;
|
||||
|
||||
fn export_keying_material(
|
||||
&self,
|
||||
@@ -1111,9 +1115,9 @@ pub(crate) trait State<Data>: Send + Sync {
|
||||
fn perhaps_write_key_update(&mut self, _cx: &mut CommonState) {}
|
||||
}
|
||||
|
||||
pub(crate) struct Context<'a, Data> {
|
||||
pub(crate) struct Context<'a> {
|
||||
pub(crate) common: &'a mut CommonState,
|
||||
pub(crate) data: &'a mut Data,
|
||||
pub(crate) data: &'a mut ClientConnectionData,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
|
||||
@@ -318,7 +318,7 @@ mod hash_hs;
|
||||
mod limited_cache;
|
||||
mod rand;
|
||||
mod record_layer;
|
||||
mod stream;
|
||||
//mod stream;
|
||||
#[cfg(feature = "tls12")]
|
||||
mod tls12;
|
||||
mod tls13;
|
||||
@@ -367,7 +367,7 @@ pub use crate::msgs::enums::CipherSuite;
|
||||
pub use crate::msgs::enums::ProtocolVersion;
|
||||
pub use crate::msgs::enums::SignatureScheme;
|
||||
pub use crate::msgs::handshake::DistinguishedNames;
|
||||
pub use crate::stream::{Stream, StreamOwned};
|
||||
//pub use crate::stream::{Stream, StreamOwned};
|
||||
pub use crate::suites::{
|
||||
BulkAlgorithm, SupportedCipherSuite, ALL_CIPHER_SUITES, DEFAULT_CIPHER_SUITES,
|
||||
};
|
||||
|
||||
@@ -156,12 +156,13 @@ impl RecordLayer {
|
||||
/// `encr` is a decoded message allegedly received from the peer.
|
||||
/// If it can be decrypted, its decryption is returned. Otherwise,
|
||||
/// an error is returned.
|
||||
pub(crate) fn decrypt_incoming(&mut self, encr: OpaqueMessage) -> Result<PlainMessage, Error> {
|
||||
pub(crate) async fn decrypt_incoming(
|
||||
&mut self,
|
||||
encr: OpaqueMessage,
|
||||
) -> Result<PlainMessage, Error> {
|
||||
debug_assert!(self.is_decrypting());
|
||||
let seq = self.read_seq;
|
||||
let msg = self
|
||||
.message_decrypter
|
||||
.decrypt(encr, seq)?;
|
||||
let msg = self.message_decrypter.decrypt(encr, seq)?;
|
||||
self.read_seq += 1;
|
||||
Ok(msg)
|
||||
}
|
||||
@@ -170,13 +171,14 @@ impl RecordLayer {
|
||||
///
|
||||
/// `plain` is a TLS message we'd like to send. This function
|
||||
/// panics if the requisite keying material hasn't been established yet.
|
||||
pub(crate) fn encrypt_outgoing(&mut self, plain: BorrowedPlainMessage) -> OpaqueMessage {
|
||||
pub(crate) async fn encrypt_outgoing<'a>(
|
||||
&mut self,
|
||||
plain: BorrowedPlainMessage<'a>,
|
||||
) -> OpaqueMessage {
|
||||
debug_assert!(self.encrypt_state == DirectionState::Active);
|
||||
assert!(!self.encrypt_exhausted());
|
||||
let seq = self.write_seq;
|
||||
self.write_seq += 1;
|
||||
self.message_encrypter
|
||||
.encrypt(plain, seq)
|
||||
.unwrap()
|
||||
self.message_encrypter.encrypt(plain, seq).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user