more work on async refactor

This commit is contained in:
sinuio
2022-05-19 19:26:59 -07:00
parent 858b35ec60
commit 031acbb642
7 changed files with 362 additions and 359 deletions

View File

@@ -371,46 +371,6 @@ impl EarlyData {
}
}
/// Stub that implements io::Write and dispatches to `write_early_data`.
pub struct WriteEarlyData<'a> {
sess: &'a mut ClientConnection,
fut: Option<futures::future::BoxFuture<'a, io::Result<usize>>>,
}
impl<'a> WriteEarlyData<'a> {
fn new(sess: &'a mut ClientConnection) -> WriteEarlyData<'a> {
WriteEarlyData { sess, fut: None }
}
/// How many bytes you may send. Writes will become short
/// once this reaches zero.
pub fn bytes_left(&self) -> usize {
self.sess.inner.data.early_data.bytes_left()
}
}
impl<'a> AsyncWrite for WriteEarlyData<'a> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
let fut = match self.fut {
Some(fut) => fut,
None => Box::pin(self.sess.write_early_data(buf)),
};
fut.as_mut().poll(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
}
/// This represents a single TLS client connection.
pub struct ClientConnection {
inner: ConnectionCommon,
@@ -451,31 +411,31 @@ impl ClientConnection {
Ok(Self { inner })
}
/// Returns an `io::Write` implementer you can write bytes to
/// to send TLS1.3 early data (a.k.a. "0-RTT data") to the server.
///
/// This returns None in many circumstances when the capability to
/// send early data is not available, including but not limited to:
///
/// - The server hasn't been talked to previously.
/// - The server does not support resumption.
/// - The server does not support early data.
/// - The resumption data for the server has expired.
///
/// The server specifies a maximum amount of early data. You can
/// learn this limit through the returned object, and writes through
/// it will process only this many bytes.
///
/// The server can choose not to accept any sent early data --
/// in this case the data is lost but the connection continues. You
/// can tell this happened using `is_early_data_accepted`.
pub fn early_data(&mut self) -> Option<WriteEarlyData> {
if self.inner.data.early_data.is_enabled() {
Some(WriteEarlyData::new(self))
} else {
None
}
}
// /// Returns an `io::Write` implementer you can write bytes to
// /// to send TLS1.3 early data (a.k.a. "0-RTT data") to the server.
// ///
// /// This returns None in many circumstances when the capability to
// /// send early data is not available, including but not limited to:
// ///
// /// - The server hasn't been talked to previously.
// /// - The server does not support resumption.
// /// - The server does not support early data.
// /// - The resumption data for the server has expired.
// ///
// /// The server specifies a maximum amount of early data. You can
// /// learn this limit through the returned object, and writes through
// /// it will process only this many bytes.
// ///
// /// The server can choose not to accept any sent early data --
// /// in this case the data is lost but the connection continues. You
// /// can tell this happened using `is_early_data_accepted`.
// pub fn early_data(&mut self) -> Option<WriteEarlyData> {
// if self.inner.data.early_data.is_enabled() {
// Some(WriteEarlyData::new(self))
// } else {
// None
// }
// }
/// Returns True if the server signalled it will process early data.
///

View File

@@ -256,27 +256,29 @@ async fn emit_client_hello_for_retry(
&& resume_version == ProtocolVersion::TLSv1_3
&& !ticket.is_empty()
{
resuming_session
.as_ref()
.and_then(|resuming| match (suite, resuming.tls13()) {
(Some(suite), Some(resuming)) => {
suite.tls13()?.can_resume_from(resuming.suite())?;
Some(resuming)
}
(None, Some(resuming)) => Some(resuming),
_ => None,
})
.map(|resuming| {
tls13::prepare_resumption(
&config,
cx,
ticket,
&resuming,
&mut exts,
retryreq.is_some(),
);
resuming
})
let resuming =
resuming_session
.as_ref()
.and_then(|resuming| match (suite, resuming.tls13()) {
(Some(suite), Some(resuming)) => {
suite.tls13()?.can_resume_from(resuming.suite())?;
Some(resuming)
}
(None, Some(resuming)) => Some(resuming),
_ => None,
});
if let Some(ref resuming) = resuming {
tls13::prepare_resumption(
&config,
cx,
ticket,
&resuming,
&mut exts,
retryreq.is_some(),
)
.await;
}
resuming
} else if config.enable_tickets {
// If we have a ticket, include it. Otherwise, request one.
if ticket.is_empty() {
@@ -333,31 +335,33 @@ async fn emit_client_hello_for_retry(
if retryreq.is_some() {
// send dummy CCS to fool middleboxes prior
// to second client hello
tls13::emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common);
tls13::emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common).await;
}
trace!("Sending ClientHello {:#?}", ch);
transcript_buffer.add_message(&ch);
cx.common.send_msg(ch, false);
cx.common.send_msg(ch, false).await;
// Calculate the hash of ClientHello and use it to derive EarlyTrafficSecret
let early_key_schedule = early_key_schedule.map(|(resuming_suite, schedule)| {
if !cx.data.early_data.is_enabled() {
return schedule;
let early_key_schedule = match early_key_schedule {
Some((resuming_suite, schedule)) => {
if cx.data.early_data.is_enabled() {
tls13::derive_early_traffic_secret(
&*config.key_log,
cx,
resuming_suite,
&schedule,
&mut sent_tls13_fake_ccs,
&transcript_buffer,
&random.0,
)
.await;
}
Some(schedule)
}
tls13::derive_early_traffic_secret(
&*config.key_log,
cx,
resuming_suite,
&schedule,
&mut sent_tls13_fake_ccs,
&transcript_buffer,
&random.0,
);
schedule
});
None => None,
};
let next = ExpectServerHello {
config,
@@ -381,7 +385,7 @@ async fn emit_client_hello_for_retry(
}
}
pub(super) fn process_alpn_protocol(
pub(super) async fn process_alpn_protocol(
common: &mut CommonState,
config: &ClientConfig,
proto: Option<&[u8]>,
@@ -390,7 +394,9 @@ pub(super) fn process_alpn_protocol(
if let Some(alpn_protocol) = &common.alpn_protocol {
if !config.alpn_protocols.contains(alpn_protocol) {
return Err(common.illegal_param("server sent non-offered ALPN protocol"));
return Err(common
.illegal_param("server sent non-offered ALPN protocol")
.await);
}
}
@@ -441,14 +447,16 @@ impl State<ClientConnectionData> for ExpectServerHello {
if server_hello.get_supported_versions().is_some() {
return Err(cx
.common
.illegal_param("server chose v1.2 using v1.3 extension"));
.illegal_param("server chose v1.2 using v1.3 extension")
.await);
}
TLSv1_2
}
_ => {
cx.common
.send_fatal_alert(AlertDescription::ProtocolVersion);
.send_fatal_alert(AlertDescription::ProtocolVersion)
.await;
let msg = match server_version {
TLSv1_2 | TLSv1_3 => "server's TLS version is disabled in client",
_ => "server does not support TLS v1.2/v1.3",
@@ -458,11 +466,16 @@ impl State<ClientConnectionData> for ExpectServerHello {
};
if server_hello.compression_method != Compression::Null {
return Err(cx.common.illegal_param("server chose non-Null compression"));
return Err(cx
.common
.illegal_param("server chose non-Null compression")
.await);
}
if server_hello.has_duplicate_extension() {
cx.common.send_fatal_alert(AlertDescription::DecodeError);
cx.common
.send_fatal_alert(AlertDescription::DecodeError)
.await;
return Err(Error::PeerMisbehavedError(
"server sent duplicate extensions".to_string(),
));
@@ -474,7 +487,8 @@ impl State<ClientConnectionData> for ExpectServerHello {
.server_sent_unsolicited_extensions(&server_hello.extensions, &allowed_unsolicited)
{
cx.common
.send_fatal_alert(AlertDescription::UnsupportedExtension);
.send_fatal_alert(AlertDescription::UnsupportedExtension)
.await;
return Err(Error::PeerMisbehavedError(
"server sent unsolicited extension".to_string(),
));
@@ -484,7 +498,8 @@ impl State<ClientConnectionData> for ExpectServerHello {
// Extract ALPN protocol
if !cx.common.is_tls13() {
process_alpn_protocol(cx.common, &self.config, server_hello.get_alpn_protocol())?;
process_alpn_protocol(cx.common, &self.config, server_hello.get_alpn_protocol())
.await?;
}
// If ECPointFormats extension is supplied by the server, it must contain
@@ -492,33 +507,39 @@ impl State<ClientConnectionData> for ExpectServerHello {
if let Some(point_fmts) = server_hello.get_ecpoints_extension() {
if !point_fmts.contains(&ECPointFormat::Uncompressed) {
cx.common
.send_fatal_alert(AlertDescription::HandshakeFailure);
.send_fatal_alert(AlertDescription::HandshakeFailure)
.await;
return Err(Error::PeerMisbehavedError(
"server does not support uncompressed points".to_string(),
));
}
}
let suite = self
.config
.find_cipher_suite(server_hello.cipher_suite)
.ok_or_else(|| {
let suite = match self.config.find_cipher_suite(server_hello.cipher_suite) {
Some(suite) => suite,
None => {
cx.common
.send_fatal_alert(AlertDescription::HandshakeFailure);
Error::PeerMisbehavedError("server chose non-offered ciphersuite".to_string())
})?;
.send_fatal_alert(AlertDescription::HandshakeFailure)
.await;
return Err(Error::PeerMisbehavedError(
"server chose non-offered ciphersuite".to_string(),
));
}
};
if version != suite.version().version {
return Err(cx
.common
.illegal_param("server chose unusable ciphersuite for version"));
.illegal_param("server chose unusable ciphersuite for version")
.await);
}
match self.suite {
Some(prev_suite) if prev_suite != suite => {
return Err(cx
.common
.illegal_param("server varied selected ciphersuite"));
.illegal_param("server varied selected ciphersuite")
.await);
}
_ => {
debug!("Using ciphersuite {:?}", suite);
@@ -602,7 +623,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
)?;
trace!("Got HRR {:?}", hrr);
cx.common.check_aligned_handshake()?;
cx.common.check_aligned_handshake().await?;
let cookie = hrr.get_cookie();
let req_group = hrr.get_requested_key_share_group();
@@ -615,7 +636,8 @@ impl ExpectServerHelloOrHelloRetryRequest {
if cookie.is_none() && req_group == Some(offered_key_share.group()) {
return Err(cx
.common
.illegal_param("server requested hrr with our group"));
.illegal_param("server requested hrr with our group")
.await);
}
// Or has an empty cookie.
@@ -623,14 +645,16 @@ impl ExpectServerHelloOrHelloRetryRequest {
if cookie.0.is_empty() {
return Err(cx
.common
.illegal_param("server requested hrr with empty cookie"));
.illegal_param("server requested hrr with empty cookie")
.await);
}
}
// Or has something unrecognised
if hrr.has_unknown_extension() {
cx.common
.send_fatal_alert(AlertDescription::UnsupportedExtension);
.send_fatal_alert(AlertDescription::UnsupportedExtension)
.await;
return Err(Error::PeerIncompatibleError(
"server sent hrr with unhandled extension".to_string(),
));
@@ -640,14 +664,16 @@ impl ExpectServerHelloOrHelloRetryRequest {
if hrr.has_duplicate_extension() {
return Err(cx
.common
.illegal_param("server send duplicate hrr extensions"));
.illegal_param("server send duplicate hrr extensions")
.await);
}
// Or asks us to change nothing.
if cookie.is_none() && req_group.is_none() {
return Err(cx
.common
.illegal_param("server requested hrr with no changes"));
.illegal_param("server requested hrr with no changes")
.await);
}
// Or asks us to talk a protocol we didn't offer, or doesn't support HRR at all.
@@ -658,7 +684,8 @@ impl ExpectServerHelloOrHelloRetryRequest {
_ => {
return Err(cx
.common
.illegal_param("server requested unsupported version in hrr"));
.illegal_param("server requested unsupported version in hrr")
.await);
}
}
@@ -669,7 +696,8 @@ impl ExpectServerHelloOrHelloRetryRequest {
None => {
return Err(cx
.common
.illegal_param("server requested unsupported cs in hrr"));
.illegal_param("server requested unsupported cs in hrr")
.await);
}
};
@@ -690,11 +718,15 @@ impl ExpectServerHelloOrHelloRetryRequest {
let key_share = match req_group {
Some(group) if group != offered_key_share.group() => {
let group = kx::KeyExchange::choose(group, &self.next.config.kx_groups)
.ok_or_else(|| {
cx.common
let group = match kx::KeyExchange::choose(group, &self.next.config.kx_groups) {
Some(group) => group,
None => {
return Err(cx
.common
.illegal_param("server requested hrr with bad group")
})?;
.await)
}
};
kx::KeyExchange::start(group).ok_or(Error::FailedToGetRandomBytes)?
}
_ => offered_key_share,
@@ -742,16 +774,20 @@ impl State<ClientConnectionData> for ExpectServerHelloOrHelloRetryRequest {
}
}
pub(super) fn send_cert_error_alert(common: &mut CommonState, err: Error) -> Error {
pub(super) async fn send_cert_error_alert(common: &mut CommonState, err: Error) -> Error {
match err {
Error::InvalidCertificateEncoding => {
common.send_fatal_alert(AlertDescription::DecodeError);
common.send_fatal_alert(AlertDescription::DecodeError).await;
}
Error::PeerMisbehavedError(_) => {
common.send_fatal_alert(AlertDescription::IllegalParameter);
common
.send_fatal_alert(AlertDescription::IllegalParameter)
.await;
}
_ => {
common.send_fatal_alert(AlertDescription::BadCertificate);
common
.send_fatal_alert(AlertDescription::BadCertificate)
.await;
}
};

View File

@@ -67,7 +67,8 @@ mod server_hello {
if tls13_supported && has_downgrade_marker {
return Err(cx
.common
.illegal_param("downgrade to TLS1.2 when TLS1.3 is supported"));
.illegal_param("downgrade to TLS1.2 when TLS1.3 is supported")
.await);
}
// Doing EMS?
@@ -408,10 +409,15 @@ impl State<ClientConnectionData> for ExpectServerKx {
)?;
self.transcript.add_message(&m);
let ecdhe = opaque_kx.unwrap_given_kxa(&self.suite.kx).ok_or_else(|| {
cx.common.send_fatal_alert(AlertDescription::DecodeError);
Error::CorruptMessagePayload(ContentType::Handshake)
})?;
let ecdhe = match opaque_kx.unwrap_given_kxa(&self.suite.kx) {
Some(ecdhe) => ecdhe,
None => {
cx.common
.send_fatal_alert(AlertDescription::DecodeError)
.await;
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
}
};
// Save the signature and signed parameters for later verification.
let mut kx_params = Vec::new();
@@ -439,7 +445,7 @@ impl State<ClientConnectionData> for ExpectServerKx {
}
}
fn emit_certificate(
async fn emit_certificate(
transcript: &mut HandshakeHash,
cert_chain: CertificatePayload,
common: &mut CommonState,
@@ -453,10 +459,14 @@ fn emit_certificate(
};
transcript.add_message(&cert);
common.send_msg(cert, false);
common.send_msg(cert, false).await;
}
fn emit_clientkx(transcript: &mut HandshakeHash, common: &mut CommonState, pubkey: &PublicKey) {
async fn emit_clientkx(
transcript: &mut HandshakeHash,
common: &mut CommonState,
pubkey: &PublicKey,
) {
let mut buf = Vec::new();
let ecpoint = PayloadU8::new(Vec::from(pubkey.as_ref()));
ecpoint.encode(&mut buf);
@@ -471,10 +481,10 @@ fn emit_clientkx(transcript: &mut HandshakeHash, common: &mut CommonState, pubke
};
transcript.add_message(&ckx);
common.send_msg(ckx, false);
common.send_msg(ckx, false).await;
}
fn emit_certverify(
async fn emit_certverify(
transcript: &mut HandshakeHash,
signer: &dyn Signer,
common: &mut CommonState,
@@ -496,20 +506,20 @@ fn emit_certverify(
};
transcript.add_message(&m);
common.send_msg(m, false);
common.send_msg(m, false).await;
Ok(())
}
fn emit_ccs(common: &mut CommonState) {
async fn emit_ccs(common: &mut CommonState) {
let ccs = Message {
version: ProtocolVersion::TLSv1_2,
payload: MessagePayload::ChangeCipherSpec(ChangeCipherSpecPayload {}),
};
common.send_msg(ccs, false);
common.send_msg(ccs, false).await;
}
fn emit_finished(
async fn emit_finished(
secrets: &ConnectionSecrets,
transcript: &mut HandshakeHash,
common: &mut CommonState,
@@ -527,7 +537,7 @@ fn emit_finished(
};
transcript.add_message(&f);
common.send_msg(f, true);
common.send_msg(f, true).await;
}
struct ServerKxDetails {
@@ -712,7 +722,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
let mut st = *self;
st.transcript.add_message(&m);
cx.common.check_aligned_handshake()?;
cx.common.check_aligned_handshake().await?;
trace!("Server cert is {:?}", st.server_cert.cert_chain);
debug!("Server DNS name is {:?}", st.server_name);
@@ -738,18 +748,17 @@ impl State<ClientConnectionData> for ExpectServerDone {
.split_first()
.ok_or(Error::NoCertificatesPresented)?;
let now = std::time::SystemTime::now();
let cert_verified = st
.config
.verifier
.verify_server_cert(
end_entity,
intermediates,
&st.server_name,
&mut st.server_cert.scts(),
&st.server_cert.ocsp_response,
now,
)
.map_err(|err| hs::send_cert_error_alert(cx.common, err))?;
let cert_verified = match st.config.verifier.verify_server_cert(
end_entity,
intermediates,
&st.server_name,
&mut st.server_cert.scts(),
&st.server_cert.ocsp_response,
now,
) {
Ok(cert_verified) => cert_verified,
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await),
};
// 3.
// Build up the contents of the signed message.
@@ -772,10 +781,14 @@ impl State<ClientConnectionData> for ExpectServerDone {
return Err(Error::PeerMisbehavedError(error_message));
}
st.config
.verifier
.verify_tls12_signature(&message, &st.server_cert.cert_chain[0], sig)
.map_err(|err| hs::send_cert_error_alert(cx.common, err))?
match st.config.verifier.verify_tls12_signature(
&message,
&st.server_cert.cert_chain[0],
sig,
) {
Ok(sig_verified) => sig_verified,
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await),
}
};
cx.common.peer_certificates = Some(st.server_cert.cert_chain);
@@ -785,7 +798,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
ClientAuthDetails::Empty { .. } => Vec::new(),
ClientAuthDetails::Verify { certkey, .. } => certkey.cert.clone(),
};
emit_certificate(&mut st.transcript, certs, cx.common);
emit_certificate(&mut st.transcript, certs, cx.common).await;
}
// 5a.
@@ -800,17 +813,17 @@ impl State<ClientConnectionData> for ExpectServerDone {
// 5b.
let mut transcript = st.transcript;
emit_clientkx(&mut transcript, cx.common, &kx.pubkey);
emit_clientkx(&mut transcript, cx.common, &kx.pubkey).await;
// nb. EMS handshake hash only runs up to ClientKeyExchange.
let ems_seed = st.using_ems.then(|| transcript.get_current_hash());
// 5c.
if let Some(ClientAuthDetails::Verify { signer, .. }) = &st.client_auth {
emit_certverify(&mut transcript, signer.as_ref(), cx.common)?;
emit_certverify(&mut transcript, signer.as_ref(), cx.common).await?;
}
// 5d.
emit_ccs(cx.common);
emit_ccs(cx.common).await;
// 5e. Now commit secrets.
let secrets = ConnectionSecrets::from_key_exchange(
@@ -830,7 +843,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
cx.common.record_layer.start_encrypting();
// 6.
emit_finished(&secrets, &mut transcript, cx.common);
emit_finished(&secrets, &mut transcript, cx.common).await;
if st.must_issue_new_ticket {
Ok(Box::new(ExpectNewTicket {
@@ -940,7 +953,7 @@ impl State<ClientConnectionData> for ExpectCcs {
}
// CCS should not be received interleaved with fragmented handshake-level
// message.
cx.common.check_aligned_handshake()?;
cx.common.check_aligned_handshake().await?;
// nb. msgs layer validates trivial contents of CCS
cx.common.record_layer.start_decrypting();
@@ -1040,7 +1053,7 @@ impl State<ClientConnectionData> for ExpectFinished {
let finished =
require_handshake_msg!(m, HandshakeType::Finished, HandshakePayload::Finished)?;
cx.common.check_aligned_handshake()?;
cx.common.check_aligned_handshake().await?;
// Work out what verify_data we expect.
let vh = st.transcript.get_current_hash();
@@ -1049,12 +1062,15 @@ impl State<ClientConnectionData> for ExpectFinished {
// Constant-time verification of this is relatively unimportant: they only
// get one chance. But it can't hurt.
let _fin_verified =
constant_time::verify_slices_are_equal(&expect_verify_data, &finished.0)
.map_err(|_| {
cx.common.send_fatal_alert(AlertDescription::DecryptError);
Error::DecryptError
})
.map(|_| verify::FinishedMessageVerified::assertion())?;
match constant_time::verify_slices_are_equal(&expect_verify_data, &finished.0) {
Ok(()) => verify::FinishedMessageVerified::assertion(),
Err(_) => {
cx.common
.send_fatal_alert(AlertDescription::DecryptError)
.await;
return Err(Error::DecryptError);
}
};
// Hash this message too.
st.transcript.add_message(&m);
@@ -1062,9 +1078,9 @@ impl State<ClientConnectionData> for ExpectFinished {
st.save_session(cx);
if st.resuming {
emit_ccs(cx.common);
emit_ccs(cx.common).await;
cx.common.record_layer.start_encrypting();
emit_finished(&st.secrets, &mut st.transcript, cx.common);
emit_finished(&st.secrets, &mut st.transcript, cx.common).await;
}
cx.common.start_traffic();

View File

@@ -35,6 +35,7 @@ use crate::client::common::{ClientAuthDetails, ClientHelloDetails};
use crate::client::{hs, ClientConfig, ServerName, StoresClientSessions};
use crate::ticketer::TimeBase;
use futures::Future;
use ring::constant_time;
use crate::sign::{CertifiedKey, Signer};
@@ -71,16 +72,20 @@ pub(super) async fn handle_server_hello(
our_key_share: kx::KeyExchange,
mut sent_tls13_fake_ccs: bool,
) -> hs::NextStateOrError {
validate_server_hello(cx.common, server_hello)?;
validate_server_hello(cx.common, server_hello).await?;
let their_key_share = server_hello.get_key_share().ok_or_else(|| {
cx.common
.send_fatal_alert(AlertDescription::MissingExtension);
Error::PeerMisbehavedError("missing key share".to_string())
})?;
let their_key_share = match server_hello.get_key_share() {
Some(ks) => ks,
None => {
cx.common
.send_fatal_alert(AlertDescription::MissingExtension)
.await;
return Err(Error::PeerMisbehavedError("missing key share".to_string()));
}
};
if our_key_share.group() != their_key_share.group {
return Err(cx.common.illegal_param("wrong group for key share"));
return Err(cx.common.illegal_param("wrong group for key share").await);
}
let key_schedule_pre_handshake = if let (Some(selected_psk), Some(early_key_schedule)) =
@@ -92,7 +97,8 @@ pub(super) async fn handle_server_hello(
None => {
return Err(cx
.common
.illegal_param("server resuming incompatible suite"));
.illegal_param("server resuming incompatible suite")
.await);
}
};
@@ -101,11 +107,12 @@ pub(super) async fn handle_server_hello(
if cx.data.early_data.is_enabled() && resuming_suite != suite {
return Err(cx
.common
.illegal_param("server varied suite with early data"));
.illegal_param("server varied suite with early data")
.await);
}
if selected_psk != 0 {
return Err(cx.common.illegal_param("server selected invalid psk"));
return Err(cx.common.illegal_param("server selected invalid psk").await);
}
debug!("Resuming using PSK");
@@ -134,7 +141,7 @@ pub(super) async fn handle_server_hello(
// If we change keying when a subsequent handshake message is being joined,
// the two halves will have different record layer protections. Disallow this.
cx.common.check_aligned_handshake()?;
cx.common.check_aligned_handshake().await?;
let hash_at_client_recvd_server_hello = transcript.get_current_hash();
@@ -156,7 +163,7 @@ pub(super) async fn handle_server_hello(
.set_message_encrypter(suite.derive_encrypter(&client_key));
}
emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common);
emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common).await;
Ok(Box::new(ExpectEncryptedExtensions {
config,
@@ -170,13 +177,15 @@ pub(super) async fn handle_server_hello(
}))
}
fn validate_server_hello(
async fn validate_server_hello(
common: &mut CommonState,
server_hello: &ServerHelloPayload,
) -> Result<(), Error> {
for ext in &server_hello.extensions {
if !ALLOWED_PLAINTEXT_EXTS.contains(&ext.get_type()) {
common.send_fatal_alert(AlertDescription::UnsupportedExtension);
common
.send_fatal_alert(AlertDescription::UnsupportedExtension)
.await;
return Err(Error::PeerMisbehavedError(
"server sent unexpected cleartext ext".to_string(),
));
@@ -239,7 +248,7 @@ pub(super) fn fill_in_psk_binder(
key_schedule
}
pub(super) fn prepare_resumption(
pub(super) async fn prepare_resumption(
config: &ClientConfig,
cx: &mut ClientContext<'_>,
ticket: Vec<u8>,
@@ -273,17 +282,17 @@ pub(super) fn prepare_resumption(
exts.push(ClientExtension::PresharedKey(psk_ext));
}
pub(super) fn derive_early_traffic_secret(
pub(super) async fn derive_early_traffic_secret(
key_log: &dyn KeyLog,
cx: &mut ClientContext<'_>,
resuming_suite: &'static Tls13CipherSuite,
resuming_suite: &Tls13CipherSuite,
early_key_schedule: &KeyScheduleEarly,
sent_tls13_fake_ccs: &mut bool,
transcript_buffer: &HandshakeHashBuffer,
client_random: &[u8; 32],
) {
// For middlebox compatibility
emit_fake_ccs(sent_tls13_fake_ccs, cx.common);
emit_fake_ccs(sent_tls13_fake_ccs, cx.common).await;
let client_hello_hash = transcript_buffer.get_hash_given(resuming_suite.hash_algorithm(), &[]);
let client_early_traffic_secret =
@@ -298,7 +307,7 @@ pub(super) fn derive_early_traffic_secret(
trace!("Starting early data traffic");
}
pub(super) fn emit_fake_ccs(sent_tls13_fake_ccs: &mut bool, common: &mut CommonState) {
pub(super) async fn emit_fake_ccs(sent_tls13_fake_ccs: &mut bool, common: &mut CommonState) {
if std::mem::replace(sent_tls13_fake_ccs, true) {
return;
}
@@ -307,23 +316,25 @@ pub(super) fn emit_fake_ccs(sent_tls13_fake_ccs: &mut bool, common: &mut CommonS
version: ProtocolVersion::TLSv1_2,
payload: MessagePayload::ChangeCipherSpec(ChangeCipherSpecPayload {}),
};
common.send_msg(m, false);
common.send_msg(m, false).await;
}
fn validate_encrypted_extensions(
async fn validate_encrypted_extensions(
common: &mut CommonState,
hello: &ClientHelloDetails,
exts: &EncryptedExtensions,
) -> Result<(), Error> {
if exts.has_duplicate_extension() {
common.send_fatal_alert(AlertDescription::DecodeError);
common.send_fatal_alert(AlertDescription::DecodeError).await;
return Err(Error::PeerMisbehavedError(
"server sent duplicate encrypted extensions".to_string(),
));
}
if hello.server_sent_unsolicited_extensions(exts, &[]) {
common.send_fatal_alert(AlertDescription::UnsupportedExtension);
common
.send_fatal_alert(AlertDescription::UnsupportedExtension)
.await;
let msg = "server sent unsolicited encrypted extension".to_string();
return Err(Error::PeerMisbehavedError(msg));
}
@@ -332,7 +343,9 @@ fn validate_encrypted_extensions(
if ALLOWED_PLAINTEXT_EXTS.contains(&ext.get_type())
|| DISALLOWED_TLS13_EXTS.contains(&ext.get_type())
{
common.send_fatal_alert(AlertDescription::UnsupportedExtension);
common
.send_fatal_alert(AlertDescription::UnsupportedExtension)
.await;
let msg = "server sent inappropriate encrypted extension".to_string();
return Err(Error::PeerMisbehavedError(msg));
}
@@ -367,8 +380,8 @@ impl State<ClientConnectionData> for ExpectEncryptedExtensions {
debug!("TLS1.3 encrypted extensions: {:?}", exts);
self.transcript.add_message(&m);
validate_encrypted_extensions(cx.common, &self.hello, exts)?;
hs::process_alpn_protocol(cx.common, &self.config, exts.get_alpn_protocol())?;
validate_encrypted_extensions(cx.common, &self.hello, exts).await?;
hs::process_alpn_protocol(cx.common, &self.config, exts.get_alpn_protocol()).await?;
if let Some(resuming_session) = self.resuming_session {
let was_early_traffic = cx.common.early_traffic;
@@ -520,7 +533,9 @@ impl State<ClientConnectionData> for ExpectCertificateRequest {
// Must be empty during handshake.
if !certreq.context.0.is_empty() {
warn!("Server sent non-empty certreq context");
cx.common.send_fatal_alert(AlertDescription::DecodeError);
cx.common
.send_fatal_alert(AlertDescription::DecodeError)
.await;
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
}
@@ -536,7 +551,8 @@ impl State<ClientConnectionData> for ExpectCertificateRequest {
if compat_sigschemes.is_empty() {
cx.common
.send_fatal_alert(AlertDescription::HandshakeFailure);
.send_fatal_alert(AlertDescription::HandshakeFailure)
.await;
return Err(Error::PeerIncompatibleError(
"server sent bad certreq schemes".to_string(),
));
@@ -590,7 +606,9 @@ impl State<ClientConnectionData> for ExpectCertificate {
// This is only non-empty for client auth.
if !cert_chain.context.0.is_empty() {
warn!("certificate with non-empty context during handshake");
cx.common.send_fatal_alert(AlertDescription::DecodeError);
cx.common
.send_fatal_alert(AlertDescription::DecodeError)
.await;
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
}
@@ -599,7 +617,8 @@ impl State<ClientConnectionData> for ExpectCertificate {
{
warn!("certificate chain contains unsolicited/unknown extension");
cx.common
.send_fatal_alert(AlertDescription::UnsupportedExtension);
.send_fatal_alert(AlertDescription::UnsupportedExtension)
.await;
return Err(Error::PeerMisbehavedError(
"bad cert chain extensions".to_string(),
));
@@ -670,30 +689,28 @@ impl State<ClientConnectionData> for ExpectCertificateVerify {
.split_first()
.ok_or(Error::NoCertificatesPresented)?;
let now = std::time::SystemTime::now();
let cert_verified = self
.config
.verifier
.verify_server_cert(
end_entity,
intermediates,
&self.server_name,
&mut self.server_cert.scts(),
&self.server_cert.ocsp_response,
now,
)
.map_err(|err| hs::send_cert_error_alert(cx.common, err))?;
let cert_verified = match self.config.verifier.verify_server_cert(
end_entity,
intermediates,
&self.server_name,
&mut self.server_cert.scts(),
&self.server_cert.ocsp_response,
now,
) {
Ok(cert_verified) => cert_verified,
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await),
};
// 2. Verify their signature on the handshake.
let handshake_hash = self.transcript.get_current_hash();
let sig_verified = self
.config
.verifier
.verify_tls13_signature(
&verify::construct_tls13_server_verify_message(&handshake_hash),
&self.server_cert.cert_chain[0],
cert_verify,
)
.map_err(|err| hs::send_cert_error_alert(cx.common, err))?;
let sig_verified = match self.config.verifier.verify_tls13_signature(
&verify::construct_tls13_server_verify_message(&handshake_hash),
&self.server_cert.cert_chain[0],
cert_verify,
) {
Ok(sig_verified) => sig_verified,
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await),
};
cx.common.peer_certificates = Some(self.server_cert.cert_chain);
self.transcript.add_message(&m);
@@ -712,7 +729,7 @@ impl State<ClientConnectionData> for ExpectCertificateVerify {
}
}
fn emit_certificate_tls13(
async fn emit_certificate_tls13(
transcript: &mut HandshakeHash,
certkey: Option<&CertifiedKey>,
auth_context: Option<Vec<u8>>,
@@ -741,10 +758,10 @@ fn emit_certificate_tls13(
}),
};
transcript.add_message(&m);
common.send_msg(m, true);
common.send_msg(m, true).await;
}
fn emit_certverify_tls13(
async fn emit_certverify_tls13(
transcript: &mut HandshakeHash,
signer: &dyn Signer,
common: &mut CommonState,
@@ -764,11 +781,11 @@ fn emit_certverify_tls13(
};
transcript.add_message(&m);
common.send_msg(m, true);
common.send_msg(m, true).await;
Ok(())
}
fn emit_finished_tls13(
async fn emit_finished_tls13(
transcript: &mut HandshakeHash,
verify_data: ring::hmac::Tag,
common: &mut CommonState,
@@ -784,10 +801,10 @@ fn emit_finished_tls13(
};
transcript.add_message(&m);
common.send_msg(m, true);
common.send_msg(m, true).await;
}
fn emit_end_of_early_data_tls13(transcript: &mut HandshakeHash, common: &mut CommonState) {
async fn emit_end_of_early_data_tls13(transcript: &mut HandshakeHash, common: &mut CommonState) {
let m = Message {
version: ProtocolVersion::TLSv1_3,
payload: MessagePayload::Handshake(HandshakeMessagePayload {
@@ -797,7 +814,7 @@ fn emit_end_of_early_data_tls13(transcript: &mut HandshakeHash, common: &mut Com
};
transcript.add_message(&m);
common.send_msg(m, true);
common.send_msg(m, true).await;
}
struct ExpectFinished {
@@ -826,12 +843,18 @@ impl State<ClientConnectionData> for ExpectFinished {
let handshake_hash = st.transcript.get_current_hash();
let expect_verify_data = st.key_schedule.sign_server_finish(&handshake_hash);
let fin = constant_time::verify_slices_are_equal(expect_verify_data.as_ref(), &finished.0)
.map_err(|_| {
cx.common.send_fatal_alert(AlertDescription::DecryptError);
Error::DecryptError
})
.map(|_| verify::FinishedMessageVerified::assertion())?;
let fin = match constant_time::verify_slices_are_equal(
expect_verify_data.as_ref(),
&finished.0,
) {
Ok(()) => verify::FinishedMessageVerified::assertion(),
Err(_) => {
cx.common
.send_fatal_alert(AlertDescription::DecryptError)
.await;
return Err(Error::DecryptError);
}
};
st.transcript.add_message(&m);
@@ -839,7 +862,7 @@ impl State<ClientConnectionData> for ExpectFinished {
/* The EndOfEarlyData message to server is still encrypted with early data keys,
* but appears in the transcript after the server Finished. */
if cx.common.early_traffic {
emit_end_of_early_data_tls13(&mut st.transcript, cx.common);
emit_end_of_early_data_tls13(&mut st.transcript, cx.common).await;
cx.common.early_traffic = false;
cx.data.early_data.finished();
cx.common
@@ -854,7 +877,7 @@ impl State<ClientConnectionData> for ExpectFinished {
ClientAuthDetails::Empty {
auth_context_tls13: auth_context,
} => {
emit_certificate_tls13(&mut st.transcript, None, auth_context, cx.common);
emit_certificate_tls13(&mut st.transcript, None, auth_context, cx.common).await;
}
ClientAuthDetails::Verify {
certkey,
@@ -866,8 +889,9 @@ impl State<ClientConnectionData> for ExpectFinished {
Some(&certkey),
auth_context,
cx.common,
);
emit_certverify_tls13(&mut st.transcript, signer.as_ref(), cx.common)?;
)
.await;
emit_certverify_tls13(&mut st.transcript, signer.as_ref(), cx.common).await?;
}
}
}
@@ -881,10 +905,10 @@ impl State<ClientConnectionData> for ExpectFinished {
let handshake_hash = st.transcript.get_current_hash();
let (key_schedule_traffic, verify_data, _) =
key_schedule_finished.sign_client_finish(&handshake_hash);
emit_finished_tls13(&mut st.transcript, verify_data, cx.common);
emit_finished_tls13(&mut st.transcript, verify_data, cx.common).await;
/* Now move to our application traffic keys. */
cx.common.check_aligned_handshake()?;
cx.common.check_aligned_handshake().await?;
cx.common
.record_layer
@@ -936,7 +960,8 @@ impl ExpectTraffic {
) -> Result<(), Error> {
if nst.has_duplicate_extension() {
cx.common
.send_fatal_alert(AlertDescription::IllegalParameter);
.send_fatal_alert(AlertDescription::IllegalParameter)
.await;
return Err(Error::PeerMisbehavedError(
"peer sent duplicate NewSessionTicket extensions".into(),
));
@@ -987,7 +1012,7 @@ impl ExpectTraffic {
kur: &KeyUpdateRequest,
) -> Result<(), Error> {
// Mustn't be interleaved with other handshake messages.
common.check_aligned_handshake()?;
common.check_aligned_handshake().await?;
match kur {
KeyUpdateRequest::UpdateNotRequested => {}
@@ -995,7 +1020,9 @@ impl ExpectTraffic {
self.want_write_key_update = true;
}
_ => {
common.send_fatal_alert(AlertDescription::IllegalParameter);
common
.send_fatal_alert(AlertDescription::IllegalParameter)
.await;
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
}
}
@@ -1049,10 +1076,12 @@ impl State<ClientConnectionData> for ExpectTraffic {
.export_keying_material(output, label, context)
}
fn perhaps_write_key_update(&mut self, common: &mut CommonState) {
async fn perhaps_write_key_update(&mut self, common: &mut CommonState) {
if self.want_write_key_update {
self.want_write_key_update = false;
common.send_msg_encrypt(Message::build_key_update_notify().into());
common
.send_msg_encrypt(Message::build_key_update_notify().into())
.await;
let write_key = self.key_schedule.next_client_application_traffic_secret();
common

View File

@@ -21,12 +21,15 @@ use crate::tls12::ConnectionSecrets;
use crate::vecbuf::ChunkVecBuffer;
use async_trait::async_trait;
use futures::future::BoxFuture;
use futures::ready;
use std::collections::VecDeque;
use std::convert::TryFrom;
use std::io;
use std::mem;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::sync::Arc;
use std::task;
use tokio::io::AsyncWrite;
use tokio_util::sync::ReusableBoxFuture;
@@ -156,62 +159,6 @@ impl<'a> io::Read for Reader<'a> {
Ok(())
}
}
/// A structure that implements [`tokio::io::AsyncWrite`] for writing plaintext.
pub struct Writer<'a> {
con: &'a mut ConnectionCommon,
fut: Option<ReusableBoxFuture<'a, usize>>,
}
impl<'a> Writer<'a> {
/// Create a new Writer.
///
/// This is not an external interface. Get one of these objects
/// from [`Connection::writer`].
#[doc(hidden)]
pub fn new(con: &'a mut ConnectionCommon) -> Writer<'a> {
Writer { con, fut: None }
}
}
impl<'a> AsyncWrite for Writer<'a> {
/// Send the plaintext `buf` to the peer, encrypting
/// and authenticating it. Once this function succeeds
/// you should call [`CommonState::write_tls`] which will output the
/// corresponding TLS records.
///
/// This function buffers plaintext sent before the
/// TLS handshake completes, and sends it as soon
/// as it can. See [`CommonState::set_buffer_limit`] to control
/// the size of this buffer.
fn poll_write(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
let fut = match self.fut.as_mut() {
Some(fut) => fut,
None => {
let fut = self.con.send_some_plaintext(buf);
self.fut.get_or_insert(ReusableBoxFuture::new(fut))
}
};
match fut.poll(cx) {
task::Poll::Ready(sz) => task::Poll::Ready(Ok(sz)),
task::Poll::Pending => task::Poll::Pending,
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
task::Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> task::Poll<Result<(), io::Error>> {
task::Poll::Ready(Ok(()))
}
}
#[derive(Copy, Clone, Eq, PartialEq)]
pub(crate) enum Protocol {
@@ -258,6 +205,7 @@ pub struct ConnectionCommon {
pub(crate) common_state: CommonState,
message_deframer: MessageDeframer,
handshake_joiner: HandshakeJoiner,
write_fut: Option<ReusableBoxFuture<'static, usize>>,
}
impl ConnectionCommon {
@@ -272,6 +220,7 @@ impl ConnectionCommon {
common_state,
message_deframer: MessageDeframer::new(),
handshake_joiner: HandshakeJoiner::new(),
write_fut: None,
}
}
@@ -287,9 +236,9 @@ impl ConnectionCommon {
}
}
/// Returns an object that allows writing plaintext.
pub fn writer(&mut self) -> Writer {
Writer::new(self)
/// Writes plaintext
pub async fn write(&mut self, buf: &[u8]) -> usize {
self.send_some_plaintext(buf).await
}
/// This function uses `io` to complete any outstanding IO for
@@ -372,7 +321,7 @@ impl ConnectionCommon {
///
/// This is a shortcut to the `process_new_packets()` -> `process_msg()` ->
/// `process_handshake_messages()` path, specialized for the first handshake message.
pub(crate) fn first_handshake_message(&mut self) -> Result<Option<Message>, Error> {
pub(crate) async fn first_handshake_message(&mut self) -> Result<Option<Message>, Error> {
if self.message_deframer.desynced {
return Err(Error::CorruptMessage);
}
@@ -389,7 +338,8 @@ impl ConnectionCommon {
if self.handshake_joiner.take_message(msg).is_none() {
self.common_state
.send_fatal_alert(AlertDescription::DecodeError);
.send_fatal_alert(AlertDescription::DecodeError)
.await;
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
}
@@ -418,7 +368,8 @@ impl ConnectionCommon {
// which receives a protected change_cipher_spec record MUST abort the
// handshake with an "unexpected_message" alert."
self.common_state
.send_fatal_alert(AlertDescription::UnexpectedMessage);
.send_fatal_alert(AlertDescription::UnexpectedMessage)
.await;
return Err(Error::PeerMisbehavedError(
"illegal middlebox CCS received".into(),
));
@@ -450,11 +401,15 @@ impl ConnectionCommon {
// First decryptable handshake message concludes trial decryption
self.common_state.record_layer.finish_trial_decryption();
self.handshake_joiner.take_message(msg).ok_or_else(|| {
self.common_state
.send_fatal_alert(AlertDescription::DecodeError);
Error::CorruptMessagePayload(ContentType::Handshake)
})?;
match self.handshake_joiner.take_message(msg) {
Some(_) => {}
None => {
self.common_state
.send_fatal_alert(AlertDescription::DecodeError)
.await;
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
}
}
return self.process_new_handshake_messages(state).await;
}
@@ -463,7 +418,7 @@ impl ConnectionCommon {
// For alerts, we have separate logic.
if let MessagePayload::Alert(alert) = &msg.payload {
self.common_state.process_alert(alert)?;
self.common_state.process_alert(alert).await?;
return Ok(state);
}
@@ -534,7 +489,7 @@ impl ConnectionCommon {
pub(crate) async fn send_some_plaintext(&mut self, buf: &[u8]) -> usize {
if let Ok(st) = &mut self.state {
st.perhaps_write_key_update(&mut self.common_state);
st.perhaps_write_key_update(&mut self.common_state).await;
}
self.common_state.send_some_plaintext(buf).await
}
@@ -730,7 +685,8 @@ impl CommonState {
Side::Server => HandshakeType::ClientHello,
};
if msg.is_handshake_type(reject_ty) {
self.send_warning_alert(AlertDescription::NoRenegotiation);
self.send_warning_alert(AlertDescription::NoRenegotiation)
.await;
return Ok(state);
}
}
@@ -743,7 +699,8 @@ impl CommonState {
}
Err(e @ Error::InappropriateMessage { .. })
| Err(e @ Error::InappropriateHandshakeMessage { .. }) => {
self.send_fatal_alert(AlertDescription::UnexpectedMessage);
self.send_fatal_alert(AlertDescription::UnexpectedMessage)
.await;
Err(e)
}
Err(e) => Err(e),
@@ -775,9 +732,10 @@ impl CommonState {
// messages. Otherwise the defragmented messages will have
// been protected with two different record layer protections,
// which is illegal. Not mentioned in RFC.
pub(crate) fn check_aligned_handshake(&mut self) -> Result<(), Error> {
pub(crate) async fn check_aligned_handshake(&mut self) -> Result<(), Error> {
if !self.aligned_handshake {
self.send_fatal_alert(AlertDescription::UnexpectedMessage);
self.send_fatal_alert(AlertDescription::UnexpectedMessage)
.await;
Err(Error::PeerMisbehavedError(
"key epoch or handshake flight with pending fragment".to_string(),
))
@@ -786,8 +744,9 @@ impl CommonState {
}
}
pub(crate) fn illegal_param(&mut self, why: &str) -> Error {
self.send_fatal_alert(AlertDescription::IllegalParameter);
pub(crate) async fn illegal_param(&mut self, why: &str) -> Error {
self.send_fatal_alert(AlertDescription::IllegalParameter)
.await;
Error::PeerMisbehavedError(why.to_string())
}
@@ -796,7 +755,7 @@ impl CommonState {
encr: OpaqueMessage,
) -> Result<Option<PlainMessage>, Error> {
if self.record_layer.wants_close_before_decrypt() {
self.send_close_notify();
self.send_close_notify().await;
}
let encrypted_len = encr.payload.0.len();
@@ -804,7 +763,8 @@ impl CommonState {
match plain {
Err(Error::PeerSentOversizedRecord) => {
self.send_fatal_alert(AlertDescription::RecordOverflow);
self.send_fatal_alert(AlertDescription::RecordOverflow)
.await;
Err(Error::PeerSentOversizedRecord)
}
Err(Error::DecryptError) if self.record_layer.doing_trial_decryption(encrypted_len) => {
@@ -812,7 +772,7 @@ impl CommonState {
Ok(None)
}
Err(Error::DecryptError) => {
self.send_fatal_alert(AlertDescription::BadRecordMac);
self.send_fatal_alert(AlertDescription::BadRecordMac).await;
Err(Error::DecryptError)
}
Err(e) => Err(e),
@@ -861,7 +821,7 @@ impl CommonState {
// Close connection once we start to run out of
// sequence space.
if self.record_layer.wants_close_before_encrypt() {
self.send_close_notify();
self.send_close_notify().await;
}
// Refuse to wrap counter at all costs. This
@@ -967,13 +927,13 @@ impl CommonState {
/// Send any buffered plaintext. Plaintext is buffered if
/// written during handshake.
fn flush_plaintext(&mut self) {
async fn flush_plaintext(&mut self) {
if !self.may_send_application_data {
return;
}
while let Some(buf) = self.sendable_plaintext.pop() {
self.send_plain(&buf, Limit::No);
self.send_plain(&buf, Limit::No).await;
}
}
@@ -983,7 +943,7 @@ impl CommonState {
}
/// Send a raw TLS message, fragmenting it if needed.
pub(crate) fn send_msg(&mut self, m: Message, must_encrypt: bool) {
pub(crate) async fn send_msg(&mut self, m: Message, must_encrypt: bool) {
if !must_encrypt {
let mut to_send = VecDeque::new();
self.message_fragmenter.fragment(m.into(), &mut to_send);
@@ -991,7 +951,7 @@ impl CommonState {
self.queue_tls_message(mm.into_unencrypted_opaque());
}
} else {
self.send_msg_encrypt(m.into());
self.send_msg_encrypt(m.into()).await;
}
}
@@ -1006,15 +966,16 @@ impl CommonState {
self.record_layer.prepare_message_decrypter(dec);
}
fn send_warning_alert(&mut self, desc: AlertDescription) {
async fn send_warning_alert(&mut self, desc: AlertDescription) {
warn!("Sending warning alert {:?}", desc);
self.send_warning_alert_no_log(desc);
self.send_warning_alert_no_log(desc).await;
}
fn process_alert(&mut self, alert: &AlertMessagePayload) -> Result<(), Error> {
async fn process_alert(&mut self, alert: &AlertMessagePayload) -> Result<(), Error> {
// Reject unknown AlertLevels.
if let AlertLevel::Unknown(_) = alert.level {
self.send_fatal_alert(AlertDescription::IllegalParameter);
self.send_fatal_alert(AlertDescription::IllegalParameter)
.await;
}
// If we get a CloseNotify, make a note to declare EOF to our
@@ -1028,7 +989,7 @@ impl CommonState {
// (except, for no good reason, user_cancelled).
if alert.level == AlertLevel::Warning {
if self.is_tls13() && alert.description != AlertDescription::UserCanceled {
self.send_fatal_alert(AlertDescription::DecodeError);
self.send_fatal_alert(AlertDescription::DecodeError).await;
} else {
warn!("TLS alert warning received: {:#?}", alert);
return Ok(());
@@ -1039,25 +1000,26 @@ impl CommonState {
Err(Error::AlertReceived(alert.description))
}
pub(crate) fn send_fatal_alert(&mut self, desc: AlertDescription) {
pub(crate) async fn send_fatal_alert(&mut self, desc: AlertDescription) {
warn!("Sending fatal alert {:?}", desc);
debug_assert!(!self.sent_fatal_alert);
let m = Message::build_alert(AlertLevel::Fatal, desc);
self.send_msg(m, self.record_layer.is_encrypting());
self.send_msg(m, self.record_layer.is_encrypting()).await;
self.sent_fatal_alert = true;
}
/// Queues a close_notify warning alert to be sent in the next
/// [`CommonState::write_tls`] call. This informs the peer that the
/// connection is being closed.
pub fn send_close_notify(&mut self) {
pub async fn send_close_notify(&mut self) {
debug!("Sending warning alert {:?}", AlertDescription::CloseNotify);
self.send_warning_alert_no_log(AlertDescription::CloseNotify);
self.send_warning_alert_no_log(AlertDescription::CloseNotify)
.await;
}
fn send_warning_alert_no_log(&mut self, desc: AlertDescription) {
async fn send_warning_alert_no_log(&mut self, desc: AlertDescription) {
let m = Message::build_alert(AlertLevel::Warning, desc);
self.send_msg(m, self.record_layer.is_encrypting());
self.send_msg(m, self.record_layer.is_encrypting()).await;
}
pub(crate) fn set_max_fragment_size(&mut self, new: Option<usize>) -> Result<(), Error> {
@@ -1112,7 +1074,7 @@ pub(crate) trait State<ClientConnectionData>: Send + Sync {
Err(Error::HandshakeNotComplete)
}
fn perhaps_write_key_update(&mut self, _cx: &mut CommonState) {}
async fn perhaps_write_key_update(&mut self, _cx: &mut CommonState) {}
}
pub(crate) struct Context<'a> {

View File

@@ -357,7 +357,7 @@ pub use crate::anchors::{OwnedTrustAnchor, RootCertStore};
pub use crate::builder::{
ConfigBuilder, WantsCipherSuites, WantsKxGroups, WantsVerifier, WantsVersions,
};
pub use crate::conn::{CommonState, ConnectionCommon, IoState, Reader, SideData, Writer};
pub use crate::conn::{CommonState, ConnectionCommon, IoState, Reader, SideData};
pub use crate::error::Error;
pub use crate::key::{Certificate, PrivateKey};
pub use crate::key_log::{KeyLog, NoKeyLog};
@@ -393,7 +393,7 @@ pub mod client {
pub use client_conn::ResolvesClientCert;
pub use client_conn::ServerName;
pub use client_conn::StoresClientSessions;
pub use client_conn::{ClientConfig, ClientConnection, ClientConnectionData, WriteEarlyData};
pub use client_conn::{ClientConfig, ClientConnection, ClientConnectionData};
pub use handy::{ClientSessionMemoryCache, NoClientSessionStorage};
#[cfg(feature = "dangerous_configuration")]

View File

@@ -104,7 +104,7 @@ pub trait ServerCertVerifier: Send + Sync {
end_entity: &Certificate,
intermediates: &[Certificate],
server_name: &ServerName,
scts: &mut dyn Iterator<Item = &[u8]>,
scts: &mut (dyn Iterator<Item = &[u8]> + Send),
ocsp_response: &[u8],
now: SystemTime,
) -> Result<ServerCertVerified, Error>;
@@ -295,7 +295,7 @@ impl ServerCertVerifier for WebPkiVerifier {
end_entity: &Certificate,
intermediates: &[Certificate],
server_name: &ServerName,
scts: &mut dyn Iterator<Item = &[u8]>,
scts: &mut (dyn Iterator<Item = &[u8]> + Send),
ocsp_response: &[u8],
now: SystemTime,
) -> Result<ServerCertVerified, Error> {