mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-01-09 21:38:00 -05:00
more work on async refactor
This commit is contained in:
@@ -371,46 +371,6 @@ impl EarlyData {
|
||||
}
|
||||
}
|
||||
|
||||
/// Stub that implements io::Write and dispatches to `write_early_data`.
|
||||
pub struct WriteEarlyData<'a> {
|
||||
sess: &'a mut ClientConnection,
|
||||
fut: Option<futures::future::BoxFuture<'a, io::Result<usize>>>,
|
||||
}
|
||||
|
||||
impl<'a> WriteEarlyData<'a> {
|
||||
fn new(sess: &'a mut ClientConnection) -> WriteEarlyData<'a> {
|
||||
WriteEarlyData { sess, fut: None }
|
||||
}
|
||||
|
||||
/// How many bytes you may send. Writes will become short
|
||||
/// once this reaches zero.
|
||||
pub fn bytes_left(&self) -> usize {
|
||||
self.sess.inner.data.early_data.bytes_left()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> AsyncWrite for WriteEarlyData<'a> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
let fut = match self.fut {
|
||||
Some(fut) => fut,
|
||||
None => Box::pin(self.sess.write_early_data(buf)),
|
||||
};
|
||||
fut.as_mut().poll(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
/// This represents a single TLS client connection.
|
||||
pub struct ClientConnection {
|
||||
inner: ConnectionCommon,
|
||||
@@ -451,31 +411,31 @@ impl ClientConnection {
|
||||
Ok(Self { inner })
|
||||
}
|
||||
|
||||
/// Returns an `io::Write` implementer you can write bytes to
|
||||
/// to send TLS1.3 early data (a.k.a. "0-RTT data") to the server.
|
||||
///
|
||||
/// This returns None in many circumstances when the capability to
|
||||
/// send early data is not available, including but not limited to:
|
||||
///
|
||||
/// - The server hasn't been talked to previously.
|
||||
/// - The server does not support resumption.
|
||||
/// - The server does not support early data.
|
||||
/// - The resumption data for the server has expired.
|
||||
///
|
||||
/// The server specifies a maximum amount of early data. You can
|
||||
/// learn this limit through the returned object, and writes through
|
||||
/// it will process only this many bytes.
|
||||
///
|
||||
/// The server can choose not to accept any sent early data --
|
||||
/// in this case the data is lost but the connection continues. You
|
||||
/// can tell this happened using `is_early_data_accepted`.
|
||||
pub fn early_data(&mut self) -> Option<WriteEarlyData> {
|
||||
if self.inner.data.early_data.is_enabled() {
|
||||
Some(WriteEarlyData::new(self))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
// /// Returns an `io::Write` implementer you can write bytes to
|
||||
// /// to send TLS1.3 early data (a.k.a. "0-RTT data") to the server.
|
||||
// ///
|
||||
// /// This returns None in many circumstances when the capability to
|
||||
// /// send early data is not available, including but not limited to:
|
||||
// ///
|
||||
// /// - The server hasn't been talked to previously.
|
||||
// /// - The server does not support resumption.
|
||||
// /// - The server does not support early data.
|
||||
// /// - The resumption data for the server has expired.
|
||||
// ///
|
||||
// /// The server specifies a maximum amount of early data. You can
|
||||
// /// learn this limit through the returned object, and writes through
|
||||
// /// it will process only this many bytes.
|
||||
// ///
|
||||
// /// The server can choose not to accept any sent early data --
|
||||
// /// in this case the data is lost but the connection continues. You
|
||||
// /// can tell this happened using `is_early_data_accepted`.
|
||||
// pub fn early_data(&mut self) -> Option<WriteEarlyData> {
|
||||
// if self.inner.data.early_data.is_enabled() {
|
||||
// Some(WriteEarlyData::new(self))
|
||||
// } else {
|
||||
// None
|
||||
// }
|
||||
// }
|
||||
|
||||
/// Returns True if the server signalled it will process early data.
|
||||
///
|
||||
|
||||
@@ -256,27 +256,29 @@ async fn emit_client_hello_for_retry(
|
||||
&& resume_version == ProtocolVersion::TLSv1_3
|
||||
&& !ticket.is_empty()
|
||||
{
|
||||
resuming_session
|
||||
.as_ref()
|
||||
.and_then(|resuming| match (suite, resuming.tls13()) {
|
||||
(Some(suite), Some(resuming)) => {
|
||||
suite.tls13()?.can_resume_from(resuming.suite())?;
|
||||
Some(resuming)
|
||||
}
|
||||
(None, Some(resuming)) => Some(resuming),
|
||||
_ => None,
|
||||
})
|
||||
.map(|resuming| {
|
||||
tls13::prepare_resumption(
|
||||
&config,
|
||||
cx,
|
||||
ticket,
|
||||
&resuming,
|
||||
&mut exts,
|
||||
retryreq.is_some(),
|
||||
);
|
||||
resuming
|
||||
})
|
||||
let resuming =
|
||||
resuming_session
|
||||
.as_ref()
|
||||
.and_then(|resuming| match (suite, resuming.tls13()) {
|
||||
(Some(suite), Some(resuming)) => {
|
||||
suite.tls13()?.can_resume_from(resuming.suite())?;
|
||||
Some(resuming)
|
||||
}
|
||||
(None, Some(resuming)) => Some(resuming),
|
||||
_ => None,
|
||||
});
|
||||
if let Some(ref resuming) = resuming {
|
||||
tls13::prepare_resumption(
|
||||
&config,
|
||||
cx,
|
||||
ticket,
|
||||
&resuming,
|
||||
&mut exts,
|
||||
retryreq.is_some(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
resuming
|
||||
} else if config.enable_tickets {
|
||||
// If we have a ticket, include it. Otherwise, request one.
|
||||
if ticket.is_empty() {
|
||||
@@ -333,31 +335,33 @@ async fn emit_client_hello_for_retry(
|
||||
if retryreq.is_some() {
|
||||
// send dummy CCS to fool middleboxes prior
|
||||
// to second client hello
|
||||
tls13::emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common);
|
||||
tls13::emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common).await;
|
||||
}
|
||||
|
||||
trace!("Sending ClientHello {:#?}", ch);
|
||||
|
||||
transcript_buffer.add_message(&ch);
|
||||
cx.common.send_msg(ch, false);
|
||||
cx.common.send_msg(ch, false).await;
|
||||
|
||||
// Calculate the hash of ClientHello and use it to derive EarlyTrafficSecret
|
||||
let early_key_schedule = early_key_schedule.map(|(resuming_suite, schedule)| {
|
||||
if !cx.data.early_data.is_enabled() {
|
||||
return schedule;
|
||||
let early_key_schedule = match early_key_schedule {
|
||||
Some((resuming_suite, schedule)) => {
|
||||
if cx.data.early_data.is_enabled() {
|
||||
tls13::derive_early_traffic_secret(
|
||||
&*config.key_log,
|
||||
cx,
|
||||
resuming_suite,
|
||||
&schedule,
|
||||
&mut sent_tls13_fake_ccs,
|
||||
&transcript_buffer,
|
||||
&random.0,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Some(schedule)
|
||||
}
|
||||
|
||||
tls13::derive_early_traffic_secret(
|
||||
&*config.key_log,
|
||||
cx,
|
||||
resuming_suite,
|
||||
&schedule,
|
||||
&mut sent_tls13_fake_ccs,
|
||||
&transcript_buffer,
|
||||
&random.0,
|
||||
);
|
||||
schedule
|
||||
});
|
||||
None => None,
|
||||
};
|
||||
|
||||
let next = ExpectServerHello {
|
||||
config,
|
||||
@@ -381,7 +385,7 @@ async fn emit_client_hello_for_retry(
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn process_alpn_protocol(
|
||||
pub(super) async fn process_alpn_protocol(
|
||||
common: &mut CommonState,
|
||||
config: &ClientConfig,
|
||||
proto: Option<&[u8]>,
|
||||
@@ -390,7 +394,9 @@ pub(super) fn process_alpn_protocol(
|
||||
|
||||
if let Some(alpn_protocol) = &common.alpn_protocol {
|
||||
if !config.alpn_protocols.contains(alpn_protocol) {
|
||||
return Err(common.illegal_param("server sent non-offered ALPN protocol"));
|
||||
return Err(common
|
||||
.illegal_param("server sent non-offered ALPN protocol")
|
||||
.await);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -441,14 +447,16 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
||||
if server_hello.get_supported_versions().is_some() {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server chose v1.2 using v1.3 extension"));
|
||||
.illegal_param("server chose v1.2 using v1.3 extension")
|
||||
.await);
|
||||
}
|
||||
|
||||
TLSv1_2
|
||||
}
|
||||
_ => {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::ProtocolVersion);
|
||||
.send_fatal_alert(AlertDescription::ProtocolVersion)
|
||||
.await;
|
||||
let msg = match server_version {
|
||||
TLSv1_2 | TLSv1_3 => "server's TLS version is disabled in client",
|
||||
_ => "server does not support TLS v1.2/v1.3",
|
||||
@@ -458,11 +466,16 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
||||
};
|
||||
|
||||
if server_hello.compression_method != Compression::Null {
|
||||
return Err(cx.common.illegal_param("server chose non-Null compression"));
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server chose non-Null compression")
|
||||
.await);
|
||||
}
|
||||
|
||||
if server_hello.has_duplicate_extension() {
|
||||
cx.common.send_fatal_alert(AlertDescription::DecodeError);
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::DecodeError)
|
||||
.await;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"server sent duplicate extensions".to_string(),
|
||||
));
|
||||
@@ -474,7 +487,8 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
||||
.server_sent_unsolicited_extensions(&server_hello.extensions, &allowed_unsolicited)
|
||||
{
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::UnsupportedExtension);
|
||||
.send_fatal_alert(AlertDescription::UnsupportedExtension)
|
||||
.await;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"server sent unsolicited extension".to_string(),
|
||||
));
|
||||
@@ -484,7 +498,8 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
||||
|
||||
// Extract ALPN protocol
|
||||
if !cx.common.is_tls13() {
|
||||
process_alpn_protocol(cx.common, &self.config, server_hello.get_alpn_protocol())?;
|
||||
process_alpn_protocol(cx.common, &self.config, server_hello.get_alpn_protocol())
|
||||
.await?;
|
||||
}
|
||||
|
||||
// If ECPointFormats extension is supplied by the server, it must contain
|
||||
@@ -492,33 +507,39 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
||||
if let Some(point_fmts) = server_hello.get_ecpoints_extension() {
|
||||
if !point_fmts.contains(&ECPointFormat::Uncompressed) {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::HandshakeFailure);
|
||||
.send_fatal_alert(AlertDescription::HandshakeFailure)
|
||||
.await;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"server does not support uncompressed points".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let suite = self
|
||||
.config
|
||||
.find_cipher_suite(server_hello.cipher_suite)
|
||||
.ok_or_else(|| {
|
||||
let suite = match self.config.find_cipher_suite(server_hello.cipher_suite) {
|
||||
Some(suite) => suite,
|
||||
None => {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::HandshakeFailure);
|
||||
Error::PeerMisbehavedError("server chose non-offered ciphersuite".to_string())
|
||||
})?;
|
||||
.send_fatal_alert(AlertDescription::HandshakeFailure)
|
||||
.await;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"server chose non-offered ciphersuite".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if version != suite.version().version {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server chose unusable ciphersuite for version"));
|
||||
.illegal_param("server chose unusable ciphersuite for version")
|
||||
.await);
|
||||
}
|
||||
|
||||
match self.suite {
|
||||
Some(prev_suite) if prev_suite != suite => {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server varied selected ciphersuite"));
|
||||
.illegal_param("server varied selected ciphersuite")
|
||||
.await);
|
||||
}
|
||||
_ => {
|
||||
debug!("Using ciphersuite {:?}", suite);
|
||||
@@ -602,7 +623,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
)?;
|
||||
trace!("Got HRR {:?}", hrr);
|
||||
|
||||
cx.common.check_aligned_handshake()?;
|
||||
cx.common.check_aligned_handshake().await?;
|
||||
|
||||
let cookie = hrr.get_cookie();
|
||||
let req_group = hrr.get_requested_key_share_group();
|
||||
@@ -615,7 +636,8 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
if cookie.is_none() && req_group == Some(offered_key_share.group()) {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server requested hrr with our group"));
|
||||
.illegal_param("server requested hrr with our group")
|
||||
.await);
|
||||
}
|
||||
|
||||
// Or has an empty cookie.
|
||||
@@ -623,14 +645,16 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
if cookie.0.is_empty() {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server requested hrr with empty cookie"));
|
||||
.illegal_param("server requested hrr with empty cookie")
|
||||
.await);
|
||||
}
|
||||
}
|
||||
|
||||
// Or has something unrecognised
|
||||
if hrr.has_unknown_extension() {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::UnsupportedExtension);
|
||||
.send_fatal_alert(AlertDescription::UnsupportedExtension)
|
||||
.await;
|
||||
return Err(Error::PeerIncompatibleError(
|
||||
"server sent hrr with unhandled extension".to_string(),
|
||||
));
|
||||
@@ -640,14 +664,16 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
if hrr.has_duplicate_extension() {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server send duplicate hrr extensions"));
|
||||
.illegal_param("server send duplicate hrr extensions")
|
||||
.await);
|
||||
}
|
||||
|
||||
// Or asks us to change nothing.
|
||||
if cookie.is_none() && req_group.is_none() {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server requested hrr with no changes"));
|
||||
.illegal_param("server requested hrr with no changes")
|
||||
.await);
|
||||
}
|
||||
|
||||
// Or asks us to talk a protocol we didn't offer, or doesn't support HRR at all.
|
||||
@@ -658,7 +684,8 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
_ => {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server requested unsupported version in hrr"));
|
||||
.illegal_param("server requested unsupported version in hrr")
|
||||
.await);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -669,7 +696,8 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
None => {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server requested unsupported cs in hrr"));
|
||||
.illegal_param("server requested unsupported cs in hrr")
|
||||
.await);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -690,11 +718,15 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
|
||||
let key_share = match req_group {
|
||||
Some(group) if group != offered_key_share.group() => {
|
||||
let group = kx::KeyExchange::choose(group, &self.next.config.kx_groups)
|
||||
.ok_or_else(|| {
|
||||
cx.common
|
||||
let group = match kx::KeyExchange::choose(group, &self.next.config.kx_groups) {
|
||||
Some(group) => group,
|
||||
None => {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server requested hrr with bad group")
|
||||
})?;
|
||||
.await)
|
||||
}
|
||||
};
|
||||
kx::KeyExchange::start(group).ok_or(Error::FailedToGetRandomBytes)?
|
||||
}
|
||||
_ => offered_key_share,
|
||||
@@ -742,16 +774,20 @@ impl State<ClientConnectionData> for ExpectServerHelloOrHelloRetryRequest {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn send_cert_error_alert(common: &mut CommonState, err: Error) -> Error {
|
||||
pub(super) async fn send_cert_error_alert(common: &mut CommonState, err: Error) -> Error {
|
||||
match err {
|
||||
Error::InvalidCertificateEncoding => {
|
||||
common.send_fatal_alert(AlertDescription::DecodeError);
|
||||
common.send_fatal_alert(AlertDescription::DecodeError).await;
|
||||
}
|
||||
Error::PeerMisbehavedError(_) => {
|
||||
common.send_fatal_alert(AlertDescription::IllegalParameter);
|
||||
common
|
||||
.send_fatal_alert(AlertDescription::IllegalParameter)
|
||||
.await;
|
||||
}
|
||||
_ => {
|
||||
common.send_fatal_alert(AlertDescription::BadCertificate);
|
||||
common
|
||||
.send_fatal_alert(AlertDescription::BadCertificate)
|
||||
.await;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -67,7 +67,8 @@ mod server_hello {
|
||||
if tls13_supported && has_downgrade_marker {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("downgrade to TLS1.2 when TLS1.3 is supported"));
|
||||
.illegal_param("downgrade to TLS1.2 when TLS1.3 is supported")
|
||||
.await);
|
||||
}
|
||||
|
||||
// Doing EMS?
|
||||
@@ -408,10 +409,15 @@ impl State<ClientConnectionData> for ExpectServerKx {
|
||||
)?;
|
||||
self.transcript.add_message(&m);
|
||||
|
||||
let ecdhe = opaque_kx.unwrap_given_kxa(&self.suite.kx).ok_or_else(|| {
|
||||
cx.common.send_fatal_alert(AlertDescription::DecodeError);
|
||||
Error::CorruptMessagePayload(ContentType::Handshake)
|
||||
})?;
|
||||
let ecdhe = match opaque_kx.unwrap_given_kxa(&self.suite.kx) {
|
||||
Some(ecdhe) => ecdhe,
|
||||
None => {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::DecodeError)
|
||||
.await;
|
||||
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
|
||||
}
|
||||
};
|
||||
|
||||
// Save the signature and signed parameters for later verification.
|
||||
let mut kx_params = Vec::new();
|
||||
@@ -439,7 +445,7 @@ impl State<ClientConnectionData> for ExpectServerKx {
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_certificate(
|
||||
async fn emit_certificate(
|
||||
transcript: &mut HandshakeHash,
|
||||
cert_chain: CertificatePayload,
|
||||
common: &mut CommonState,
|
||||
@@ -453,10 +459,14 @@ fn emit_certificate(
|
||||
};
|
||||
|
||||
transcript.add_message(&cert);
|
||||
common.send_msg(cert, false);
|
||||
common.send_msg(cert, false).await;
|
||||
}
|
||||
|
||||
fn emit_clientkx(transcript: &mut HandshakeHash, common: &mut CommonState, pubkey: &PublicKey) {
|
||||
async fn emit_clientkx(
|
||||
transcript: &mut HandshakeHash,
|
||||
common: &mut CommonState,
|
||||
pubkey: &PublicKey,
|
||||
) {
|
||||
let mut buf = Vec::new();
|
||||
let ecpoint = PayloadU8::new(Vec::from(pubkey.as_ref()));
|
||||
ecpoint.encode(&mut buf);
|
||||
@@ -471,10 +481,10 @@ fn emit_clientkx(transcript: &mut HandshakeHash, common: &mut CommonState, pubke
|
||||
};
|
||||
|
||||
transcript.add_message(&ckx);
|
||||
common.send_msg(ckx, false);
|
||||
common.send_msg(ckx, false).await;
|
||||
}
|
||||
|
||||
fn emit_certverify(
|
||||
async fn emit_certverify(
|
||||
transcript: &mut HandshakeHash,
|
||||
signer: &dyn Signer,
|
||||
common: &mut CommonState,
|
||||
@@ -496,20 +506,20 @@ fn emit_certverify(
|
||||
};
|
||||
|
||||
transcript.add_message(&m);
|
||||
common.send_msg(m, false);
|
||||
common.send_msg(m, false).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_ccs(common: &mut CommonState) {
|
||||
async fn emit_ccs(common: &mut CommonState) {
|
||||
let ccs = Message {
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: MessagePayload::ChangeCipherSpec(ChangeCipherSpecPayload {}),
|
||||
};
|
||||
|
||||
common.send_msg(ccs, false);
|
||||
common.send_msg(ccs, false).await;
|
||||
}
|
||||
|
||||
fn emit_finished(
|
||||
async fn emit_finished(
|
||||
secrets: &ConnectionSecrets,
|
||||
transcript: &mut HandshakeHash,
|
||||
common: &mut CommonState,
|
||||
@@ -527,7 +537,7 @@ fn emit_finished(
|
||||
};
|
||||
|
||||
transcript.add_message(&f);
|
||||
common.send_msg(f, true);
|
||||
common.send_msg(f, true).await;
|
||||
}
|
||||
|
||||
struct ServerKxDetails {
|
||||
@@ -712,7 +722,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
|
||||
let mut st = *self;
|
||||
st.transcript.add_message(&m);
|
||||
|
||||
cx.common.check_aligned_handshake()?;
|
||||
cx.common.check_aligned_handshake().await?;
|
||||
|
||||
trace!("Server cert is {:?}", st.server_cert.cert_chain);
|
||||
debug!("Server DNS name is {:?}", st.server_name);
|
||||
@@ -738,18 +748,17 @@ impl State<ClientConnectionData> for ExpectServerDone {
|
||||
.split_first()
|
||||
.ok_or(Error::NoCertificatesPresented)?;
|
||||
let now = std::time::SystemTime::now();
|
||||
let cert_verified = st
|
||||
.config
|
||||
.verifier
|
||||
.verify_server_cert(
|
||||
end_entity,
|
||||
intermediates,
|
||||
&st.server_name,
|
||||
&mut st.server_cert.scts(),
|
||||
&st.server_cert.ocsp_response,
|
||||
now,
|
||||
)
|
||||
.map_err(|err| hs::send_cert_error_alert(cx.common, err))?;
|
||||
let cert_verified = match st.config.verifier.verify_server_cert(
|
||||
end_entity,
|
||||
intermediates,
|
||||
&st.server_name,
|
||||
&mut st.server_cert.scts(),
|
||||
&st.server_cert.ocsp_response,
|
||||
now,
|
||||
) {
|
||||
Ok(cert_verified) => cert_verified,
|
||||
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await),
|
||||
};
|
||||
|
||||
// 3.
|
||||
// Build up the contents of the signed message.
|
||||
@@ -772,10 +781,14 @@ impl State<ClientConnectionData> for ExpectServerDone {
|
||||
return Err(Error::PeerMisbehavedError(error_message));
|
||||
}
|
||||
|
||||
st.config
|
||||
.verifier
|
||||
.verify_tls12_signature(&message, &st.server_cert.cert_chain[0], sig)
|
||||
.map_err(|err| hs::send_cert_error_alert(cx.common, err))?
|
||||
match st.config.verifier.verify_tls12_signature(
|
||||
&message,
|
||||
&st.server_cert.cert_chain[0],
|
||||
sig,
|
||||
) {
|
||||
Ok(sig_verified) => sig_verified,
|
||||
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await),
|
||||
}
|
||||
};
|
||||
cx.common.peer_certificates = Some(st.server_cert.cert_chain);
|
||||
|
||||
@@ -785,7 +798,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
|
||||
ClientAuthDetails::Empty { .. } => Vec::new(),
|
||||
ClientAuthDetails::Verify { certkey, .. } => certkey.cert.clone(),
|
||||
};
|
||||
emit_certificate(&mut st.transcript, certs, cx.common);
|
||||
emit_certificate(&mut st.transcript, certs, cx.common).await;
|
||||
}
|
||||
|
||||
// 5a.
|
||||
@@ -800,17 +813,17 @@ impl State<ClientConnectionData> for ExpectServerDone {
|
||||
|
||||
// 5b.
|
||||
let mut transcript = st.transcript;
|
||||
emit_clientkx(&mut transcript, cx.common, &kx.pubkey);
|
||||
emit_clientkx(&mut transcript, cx.common, &kx.pubkey).await;
|
||||
// nb. EMS handshake hash only runs up to ClientKeyExchange.
|
||||
let ems_seed = st.using_ems.then(|| transcript.get_current_hash());
|
||||
|
||||
// 5c.
|
||||
if let Some(ClientAuthDetails::Verify { signer, .. }) = &st.client_auth {
|
||||
emit_certverify(&mut transcript, signer.as_ref(), cx.common)?;
|
||||
emit_certverify(&mut transcript, signer.as_ref(), cx.common).await?;
|
||||
}
|
||||
|
||||
// 5d.
|
||||
emit_ccs(cx.common);
|
||||
emit_ccs(cx.common).await;
|
||||
|
||||
// 5e. Now commit secrets.
|
||||
let secrets = ConnectionSecrets::from_key_exchange(
|
||||
@@ -830,7 +843,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
|
||||
cx.common.record_layer.start_encrypting();
|
||||
|
||||
// 6.
|
||||
emit_finished(&secrets, &mut transcript, cx.common);
|
||||
emit_finished(&secrets, &mut transcript, cx.common).await;
|
||||
|
||||
if st.must_issue_new_ticket {
|
||||
Ok(Box::new(ExpectNewTicket {
|
||||
@@ -940,7 +953,7 @@ impl State<ClientConnectionData> for ExpectCcs {
|
||||
}
|
||||
// CCS should not be received interleaved with fragmented handshake-level
|
||||
// message.
|
||||
cx.common.check_aligned_handshake()?;
|
||||
cx.common.check_aligned_handshake().await?;
|
||||
|
||||
// nb. msgs layer validates trivial contents of CCS
|
||||
cx.common.record_layer.start_decrypting();
|
||||
@@ -1040,7 +1053,7 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
let finished =
|
||||
require_handshake_msg!(m, HandshakeType::Finished, HandshakePayload::Finished)?;
|
||||
|
||||
cx.common.check_aligned_handshake()?;
|
||||
cx.common.check_aligned_handshake().await?;
|
||||
|
||||
// Work out what verify_data we expect.
|
||||
let vh = st.transcript.get_current_hash();
|
||||
@@ -1049,12 +1062,15 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
// Constant-time verification of this is relatively unimportant: they only
|
||||
// get one chance. But it can't hurt.
|
||||
let _fin_verified =
|
||||
constant_time::verify_slices_are_equal(&expect_verify_data, &finished.0)
|
||||
.map_err(|_| {
|
||||
cx.common.send_fatal_alert(AlertDescription::DecryptError);
|
||||
Error::DecryptError
|
||||
})
|
||||
.map(|_| verify::FinishedMessageVerified::assertion())?;
|
||||
match constant_time::verify_slices_are_equal(&expect_verify_data, &finished.0) {
|
||||
Ok(()) => verify::FinishedMessageVerified::assertion(),
|
||||
Err(_) => {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::DecryptError)
|
||||
.await;
|
||||
return Err(Error::DecryptError);
|
||||
}
|
||||
};
|
||||
|
||||
// Hash this message too.
|
||||
st.transcript.add_message(&m);
|
||||
@@ -1062,9 +1078,9 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
st.save_session(cx);
|
||||
|
||||
if st.resuming {
|
||||
emit_ccs(cx.common);
|
||||
emit_ccs(cx.common).await;
|
||||
cx.common.record_layer.start_encrypting();
|
||||
emit_finished(&st.secrets, &mut st.transcript, cx.common);
|
||||
emit_finished(&st.secrets, &mut st.transcript, cx.common).await;
|
||||
}
|
||||
|
||||
cx.common.start_traffic();
|
||||
|
||||
@@ -35,6 +35,7 @@ use crate::client::common::{ClientAuthDetails, ClientHelloDetails};
|
||||
use crate::client::{hs, ClientConfig, ServerName, StoresClientSessions};
|
||||
|
||||
use crate::ticketer::TimeBase;
|
||||
use futures::Future;
|
||||
use ring::constant_time;
|
||||
|
||||
use crate::sign::{CertifiedKey, Signer};
|
||||
@@ -71,16 +72,20 @@ pub(super) async fn handle_server_hello(
|
||||
our_key_share: kx::KeyExchange,
|
||||
mut sent_tls13_fake_ccs: bool,
|
||||
) -> hs::NextStateOrError {
|
||||
validate_server_hello(cx.common, server_hello)?;
|
||||
validate_server_hello(cx.common, server_hello).await?;
|
||||
|
||||
let their_key_share = server_hello.get_key_share().ok_or_else(|| {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::MissingExtension);
|
||||
Error::PeerMisbehavedError("missing key share".to_string())
|
||||
})?;
|
||||
let their_key_share = match server_hello.get_key_share() {
|
||||
Some(ks) => ks,
|
||||
None => {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::MissingExtension)
|
||||
.await;
|
||||
return Err(Error::PeerMisbehavedError("missing key share".to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
if our_key_share.group() != their_key_share.group {
|
||||
return Err(cx.common.illegal_param("wrong group for key share"));
|
||||
return Err(cx.common.illegal_param("wrong group for key share").await);
|
||||
}
|
||||
|
||||
let key_schedule_pre_handshake = if let (Some(selected_psk), Some(early_key_schedule)) =
|
||||
@@ -92,7 +97,8 @@ pub(super) async fn handle_server_hello(
|
||||
None => {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server resuming incompatible suite"));
|
||||
.illegal_param("server resuming incompatible suite")
|
||||
.await);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -101,11 +107,12 @@ pub(super) async fn handle_server_hello(
|
||||
if cx.data.early_data.is_enabled() && resuming_suite != suite {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server varied suite with early data"));
|
||||
.illegal_param("server varied suite with early data")
|
||||
.await);
|
||||
}
|
||||
|
||||
if selected_psk != 0 {
|
||||
return Err(cx.common.illegal_param("server selected invalid psk"));
|
||||
return Err(cx.common.illegal_param("server selected invalid psk").await);
|
||||
}
|
||||
|
||||
debug!("Resuming using PSK");
|
||||
@@ -134,7 +141,7 @@ pub(super) async fn handle_server_hello(
|
||||
|
||||
// If we change keying when a subsequent handshake message is being joined,
|
||||
// the two halves will have different record layer protections. Disallow this.
|
||||
cx.common.check_aligned_handshake()?;
|
||||
cx.common.check_aligned_handshake().await?;
|
||||
|
||||
let hash_at_client_recvd_server_hello = transcript.get_current_hash();
|
||||
|
||||
@@ -156,7 +163,7 @@ pub(super) async fn handle_server_hello(
|
||||
.set_message_encrypter(suite.derive_encrypter(&client_key));
|
||||
}
|
||||
|
||||
emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common);
|
||||
emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common).await;
|
||||
|
||||
Ok(Box::new(ExpectEncryptedExtensions {
|
||||
config,
|
||||
@@ -170,13 +177,15 @@ pub(super) async fn handle_server_hello(
|
||||
}))
|
||||
}
|
||||
|
||||
fn validate_server_hello(
|
||||
async fn validate_server_hello(
|
||||
common: &mut CommonState,
|
||||
server_hello: &ServerHelloPayload,
|
||||
) -> Result<(), Error> {
|
||||
for ext in &server_hello.extensions {
|
||||
if !ALLOWED_PLAINTEXT_EXTS.contains(&ext.get_type()) {
|
||||
common.send_fatal_alert(AlertDescription::UnsupportedExtension);
|
||||
common
|
||||
.send_fatal_alert(AlertDescription::UnsupportedExtension)
|
||||
.await;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"server sent unexpected cleartext ext".to_string(),
|
||||
));
|
||||
@@ -239,7 +248,7 @@ pub(super) fn fill_in_psk_binder(
|
||||
key_schedule
|
||||
}
|
||||
|
||||
pub(super) fn prepare_resumption(
|
||||
pub(super) async fn prepare_resumption(
|
||||
config: &ClientConfig,
|
||||
cx: &mut ClientContext<'_>,
|
||||
ticket: Vec<u8>,
|
||||
@@ -273,17 +282,17 @@ pub(super) fn prepare_resumption(
|
||||
exts.push(ClientExtension::PresharedKey(psk_ext));
|
||||
}
|
||||
|
||||
pub(super) fn derive_early_traffic_secret(
|
||||
pub(super) async fn derive_early_traffic_secret(
|
||||
key_log: &dyn KeyLog,
|
||||
cx: &mut ClientContext<'_>,
|
||||
resuming_suite: &'static Tls13CipherSuite,
|
||||
resuming_suite: &Tls13CipherSuite,
|
||||
early_key_schedule: &KeyScheduleEarly,
|
||||
sent_tls13_fake_ccs: &mut bool,
|
||||
transcript_buffer: &HandshakeHashBuffer,
|
||||
client_random: &[u8; 32],
|
||||
) {
|
||||
// For middlebox compatibility
|
||||
emit_fake_ccs(sent_tls13_fake_ccs, cx.common);
|
||||
emit_fake_ccs(sent_tls13_fake_ccs, cx.common).await;
|
||||
|
||||
let client_hello_hash = transcript_buffer.get_hash_given(resuming_suite.hash_algorithm(), &[]);
|
||||
let client_early_traffic_secret =
|
||||
@@ -298,7 +307,7 @@ pub(super) fn derive_early_traffic_secret(
|
||||
trace!("Starting early data traffic");
|
||||
}
|
||||
|
||||
pub(super) fn emit_fake_ccs(sent_tls13_fake_ccs: &mut bool, common: &mut CommonState) {
|
||||
pub(super) async fn emit_fake_ccs(sent_tls13_fake_ccs: &mut bool, common: &mut CommonState) {
|
||||
if std::mem::replace(sent_tls13_fake_ccs, true) {
|
||||
return;
|
||||
}
|
||||
@@ -307,23 +316,25 @@ pub(super) fn emit_fake_ccs(sent_tls13_fake_ccs: &mut bool, common: &mut CommonS
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: MessagePayload::ChangeCipherSpec(ChangeCipherSpecPayload {}),
|
||||
};
|
||||
common.send_msg(m, false);
|
||||
common.send_msg(m, false).await;
|
||||
}
|
||||
|
||||
fn validate_encrypted_extensions(
|
||||
async fn validate_encrypted_extensions(
|
||||
common: &mut CommonState,
|
||||
hello: &ClientHelloDetails,
|
||||
exts: &EncryptedExtensions,
|
||||
) -> Result<(), Error> {
|
||||
if exts.has_duplicate_extension() {
|
||||
common.send_fatal_alert(AlertDescription::DecodeError);
|
||||
common.send_fatal_alert(AlertDescription::DecodeError).await;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"server sent duplicate encrypted extensions".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if hello.server_sent_unsolicited_extensions(exts, &[]) {
|
||||
common.send_fatal_alert(AlertDescription::UnsupportedExtension);
|
||||
common
|
||||
.send_fatal_alert(AlertDescription::UnsupportedExtension)
|
||||
.await;
|
||||
let msg = "server sent unsolicited encrypted extension".to_string();
|
||||
return Err(Error::PeerMisbehavedError(msg));
|
||||
}
|
||||
@@ -332,7 +343,9 @@ fn validate_encrypted_extensions(
|
||||
if ALLOWED_PLAINTEXT_EXTS.contains(&ext.get_type())
|
||||
|| DISALLOWED_TLS13_EXTS.contains(&ext.get_type())
|
||||
{
|
||||
common.send_fatal_alert(AlertDescription::UnsupportedExtension);
|
||||
common
|
||||
.send_fatal_alert(AlertDescription::UnsupportedExtension)
|
||||
.await;
|
||||
let msg = "server sent inappropriate encrypted extension".to_string();
|
||||
return Err(Error::PeerMisbehavedError(msg));
|
||||
}
|
||||
@@ -367,8 +380,8 @@ impl State<ClientConnectionData> for ExpectEncryptedExtensions {
|
||||
debug!("TLS1.3 encrypted extensions: {:?}", exts);
|
||||
self.transcript.add_message(&m);
|
||||
|
||||
validate_encrypted_extensions(cx.common, &self.hello, exts)?;
|
||||
hs::process_alpn_protocol(cx.common, &self.config, exts.get_alpn_protocol())?;
|
||||
validate_encrypted_extensions(cx.common, &self.hello, exts).await?;
|
||||
hs::process_alpn_protocol(cx.common, &self.config, exts.get_alpn_protocol()).await?;
|
||||
|
||||
if let Some(resuming_session) = self.resuming_session {
|
||||
let was_early_traffic = cx.common.early_traffic;
|
||||
@@ -520,7 +533,9 @@ impl State<ClientConnectionData> for ExpectCertificateRequest {
|
||||
// Must be empty during handshake.
|
||||
if !certreq.context.0.is_empty() {
|
||||
warn!("Server sent non-empty certreq context");
|
||||
cx.common.send_fatal_alert(AlertDescription::DecodeError);
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::DecodeError)
|
||||
.await;
|
||||
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
|
||||
}
|
||||
|
||||
@@ -536,7 +551,8 @@ impl State<ClientConnectionData> for ExpectCertificateRequest {
|
||||
|
||||
if compat_sigschemes.is_empty() {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::HandshakeFailure);
|
||||
.send_fatal_alert(AlertDescription::HandshakeFailure)
|
||||
.await;
|
||||
return Err(Error::PeerIncompatibleError(
|
||||
"server sent bad certreq schemes".to_string(),
|
||||
));
|
||||
@@ -590,7 +606,9 @@ impl State<ClientConnectionData> for ExpectCertificate {
|
||||
// This is only non-empty for client auth.
|
||||
if !cert_chain.context.0.is_empty() {
|
||||
warn!("certificate with non-empty context during handshake");
|
||||
cx.common.send_fatal_alert(AlertDescription::DecodeError);
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::DecodeError)
|
||||
.await;
|
||||
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
|
||||
}
|
||||
|
||||
@@ -599,7 +617,8 @@ impl State<ClientConnectionData> for ExpectCertificate {
|
||||
{
|
||||
warn!("certificate chain contains unsolicited/unknown extension");
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::UnsupportedExtension);
|
||||
.send_fatal_alert(AlertDescription::UnsupportedExtension)
|
||||
.await;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"bad cert chain extensions".to_string(),
|
||||
));
|
||||
@@ -670,30 +689,28 @@ impl State<ClientConnectionData> for ExpectCertificateVerify {
|
||||
.split_first()
|
||||
.ok_or(Error::NoCertificatesPresented)?;
|
||||
let now = std::time::SystemTime::now();
|
||||
let cert_verified = self
|
||||
.config
|
||||
.verifier
|
||||
.verify_server_cert(
|
||||
end_entity,
|
||||
intermediates,
|
||||
&self.server_name,
|
||||
&mut self.server_cert.scts(),
|
||||
&self.server_cert.ocsp_response,
|
||||
now,
|
||||
)
|
||||
.map_err(|err| hs::send_cert_error_alert(cx.common, err))?;
|
||||
let cert_verified = match self.config.verifier.verify_server_cert(
|
||||
end_entity,
|
||||
intermediates,
|
||||
&self.server_name,
|
||||
&mut self.server_cert.scts(),
|
||||
&self.server_cert.ocsp_response,
|
||||
now,
|
||||
) {
|
||||
Ok(cert_verified) => cert_verified,
|
||||
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await),
|
||||
};
|
||||
|
||||
// 2. Verify their signature on the handshake.
|
||||
let handshake_hash = self.transcript.get_current_hash();
|
||||
let sig_verified = self
|
||||
.config
|
||||
.verifier
|
||||
.verify_tls13_signature(
|
||||
&verify::construct_tls13_server_verify_message(&handshake_hash),
|
||||
&self.server_cert.cert_chain[0],
|
||||
cert_verify,
|
||||
)
|
||||
.map_err(|err| hs::send_cert_error_alert(cx.common, err))?;
|
||||
let sig_verified = match self.config.verifier.verify_tls13_signature(
|
||||
&verify::construct_tls13_server_verify_message(&handshake_hash),
|
||||
&self.server_cert.cert_chain[0],
|
||||
cert_verify,
|
||||
) {
|
||||
Ok(sig_verified) => sig_verified,
|
||||
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await),
|
||||
};
|
||||
|
||||
cx.common.peer_certificates = Some(self.server_cert.cert_chain);
|
||||
self.transcript.add_message(&m);
|
||||
@@ -712,7 +729,7 @@ impl State<ClientConnectionData> for ExpectCertificateVerify {
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_certificate_tls13(
|
||||
async fn emit_certificate_tls13(
|
||||
transcript: &mut HandshakeHash,
|
||||
certkey: Option<&CertifiedKey>,
|
||||
auth_context: Option<Vec<u8>>,
|
||||
@@ -741,10 +758,10 @@ fn emit_certificate_tls13(
|
||||
}),
|
||||
};
|
||||
transcript.add_message(&m);
|
||||
common.send_msg(m, true);
|
||||
common.send_msg(m, true).await;
|
||||
}
|
||||
|
||||
fn emit_certverify_tls13(
|
||||
async fn emit_certverify_tls13(
|
||||
transcript: &mut HandshakeHash,
|
||||
signer: &dyn Signer,
|
||||
common: &mut CommonState,
|
||||
@@ -764,11 +781,11 @@ fn emit_certverify_tls13(
|
||||
};
|
||||
|
||||
transcript.add_message(&m);
|
||||
common.send_msg(m, true);
|
||||
common.send_msg(m, true).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_finished_tls13(
|
||||
async fn emit_finished_tls13(
|
||||
transcript: &mut HandshakeHash,
|
||||
verify_data: ring::hmac::Tag,
|
||||
common: &mut CommonState,
|
||||
@@ -784,10 +801,10 @@ fn emit_finished_tls13(
|
||||
};
|
||||
|
||||
transcript.add_message(&m);
|
||||
common.send_msg(m, true);
|
||||
common.send_msg(m, true).await;
|
||||
}
|
||||
|
||||
fn emit_end_of_early_data_tls13(transcript: &mut HandshakeHash, common: &mut CommonState) {
|
||||
async fn emit_end_of_early_data_tls13(transcript: &mut HandshakeHash, common: &mut CommonState) {
|
||||
let m = Message {
|
||||
version: ProtocolVersion::TLSv1_3,
|
||||
payload: MessagePayload::Handshake(HandshakeMessagePayload {
|
||||
@@ -797,7 +814,7 @@ fn emit_end_of_early_data_tls13(transcript: &mut HandshakeHash, common: &mut Com
|
||||
};
|
||||
|
||||
transcript.add_message(&m);
|
||||
common.send_msg(m, true);
|
||||
common.send_msg(m, true).await;
|
||||
}
|
||||
|
||||
struct ExpectFinished {
|
||||
@@ -826,12 +843,18 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
let handshake_hash = st.transcript.get_current_hash();
|
||||
let expect_verify_data = st.key_schedule.sign_server_finish(&handshake_hash);
|
||||
|
||||
let fin = constant_time::verify_slices_are_equal(expect_verify_data.as_ref(), &finished.0)
|
||||
.map_err(|_| {
|
||||
cx.common.send_fatal_alert(AlertDescription::DecryptError);
|
||||
Error::DecryptError
|
||||
})
|
||||
.map(|_| verify::FinishedMessageVerified::assertion())?;
|
||||
let fin = match constant_time::verify_slices_are_equal(
|
||||
expect_verify_data.as_ref(),
|
||||
&finished.0,
|
||||
) {
|
||||
Ok(()) => verify::FinishedMessageVerified::assertion(),
|
||||
Err(_) => {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::DecryptError)
|
||||
.await;
|
||||
return Err(Error::DecryptError);
|
||||
}
|
||||
};
|
||||
|
||||
st.transcript.add_message(&m);
|
||||
|
||||
@@ -839,7 +862,7 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
/* The EndOfEarlyData message to server is still encrypted with early data keys,
|
||||
* but appears in the transcript after the server Finished. */
|
||||
if cx.common.early_traffic {
|
||||
emit_end_of_early_data_tls13(&mut st.transcript, cx.common);
|
||||
emit_end_of_early_data_tls13(&mut st.transcript, cx.common).await;
|
||||
cx.common.early_traffic = false;
|
||||
cx.data.early_data.finished();
|
||||
cx.common
|
||||
@@ -854,7 +877,7 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
ClientAuthDetails::Empty {
|
||||
auth_context_tls13: auth_context,
|
||||
} => {
|
||||
emit_certificate_tls13(&mut st.transcript, None, auth_context, cx.common);
|
||||
emit_certificate_tls13(&mut st.transcript, None, auth_context, cx.common).await;
|
||||
}
|
||||
ClientAuthDetails::Verify {
|
||||
certkey,
|
||||
@@ -866,8 +889,9 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
Some(&certkey),
|
||||
auth_context,
|
||||
cx.common,
|
||||
);
|
||||
emit_certverify_tls13(&mut st.transcript, signer.as_ref(), cx.common)?;
|
||||
)
|
||||
.await;
|
||||
emit_certverify_tls13(&mut st.transcript, signer.as_ref(), cx.common).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -881,10 +905,10 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
let handshake_hash = st.transcript.get_current_hash();
|
||||
let (key_schedule_traffic, verify_data, _) =
|
||||
key_schedule_finished.sign_client_finish(&handshake_hash);
|
||||
emit_finished_tls13(&mut st.transcript, verify_data, cx.common);
|
||||
emit_finished_tls13(&mut st.transcript, verify_data, cx.common).await;
|
||||
|
||||
/* Now move to our application traffic keys. */
|
||||
cx.common.check_aligned_handshake()?;
|
||||
cx.common.check_aligned_handshake().await?;
|
||||
|
||||
cx.common
|
||||
.record_layer
|
||||
@@ -936,7 +960,8 @@ impl ExpectTraffic {
|
||||
) -> Result<(), Error> {
|
||||
if nst.has_duplicate_extension() {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::IllegalParameter);
|
||||
.send_fatal_alert(AlertDescription::IllegalParameter)
|
||||
.await;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"peer sent duplicate NewSessionTicket extensions".into(),
|
||||
));
|
||||
@@ -987,7 +1012,7 @@ impl ExpectTraffic {
|
||||
kur: &KeyUpdateRequest,
|
||||
) -> Result<(), Error> {
|
||||
// Mustn't be interleaved with other handshake messages.
|
||||
common.check_aligned_handshake()?;
|
||||
common.check_aligned_handshake().await?;
|
||||
|
||||
match kur {
|
||||
KeyUpdateRequest::UpdateNotRequested => {}
|
||||
@@ -995,7 +1020,9 @@ impl ExpectTraffic {
|
||||
self.want_write_key_update = true;
|
||||
}
|
||||
_ => {
|
||||
common.send_fatal_alert(AlertDescription::IllegalParameter);
|
||||
common
|
||||
.send_fatal_alert(AlertDescription::IllegalParameter)
|
||||
.await;
|
||||
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
|
||||
}
|
||||
}
|
||||
@@ -1049,10 +1076,12 @@ impl State<ClientConnectionData> for ExpectTraffic {
|
||||
.export_keying_material(output, label, context)
|
||||
}
|
||||
|
||||
fn perhaps_write_key_update(&mut self, common: &mut CommonState) {
|
||||
async fn perhaps_write_key_update(&mut self, common: &mut CommonState) {
|
||||
if self.want_write_key_update {
|
||||
self.want_write_key_update = false;
|
||||
common.send_msg_encrypt(Message::build_key_update_notify().into());
|
||||
common
|
||||
.send_msg_encrypt(Message::build_key_update_notify().into())
|
||||
.await;
|
||||
|
||||
let write_key = self.key_schedule.next_client_application_traffic_secret();
|
||||
common
|
||||
|
||||
@@ -21,12 +21,15 @@ use crate::tls12::ConnectionSecrets;
|
||||
use crate::vecbuf::ChunkVecBuffer;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::ready;
|
||||
use std::collections::VecDeque;
|
||||
use std::convert::TryFrom;
|
||||
use std::io;
|
||||
use std::mem;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio_util::sync::ReusableBoxFuture;
|
||||
@@ -156,62 +159,6 @@ impl<'a> io::Read for Reader<'a> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
/// A structure that implements [`tokio::io::AsyncWrite`] for writing plaintext.
|
||||
pub struct Writer<'a> {
|
||||
con: &'a mut ConnectionCommon,
|
||||
fut: Option<ReusableBoxFuture<'a, usize>>,
|
||||
}
|
||||
|
||||
impl<'a> Writer<'a> {
|
||||
/// Create a new Writer.
|
||||
///
|
||||
/// This is not an external interface. Get one of these objects
|
||||
/// from [`Connection::writer`].
|
||||
#[doc(hidden)]
|
||||
pub fn new(con: &'a mut ConnectionCommon) -> Writer<'a> {
|
||||
Writer { con, fut: None }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> AsyncWrite for Writer<'a> {
|
||||
/// Send the plaintext `buf` to the peer, encrypting
|
||||
/// and authenticating it. Once this function succeeds
|
||||
/// you should call [`CommonState::write_tls`] which will output the
|
||||
/// corresponding TLS records.
|
||||
///
|
||||
/// This function buffers plaintext sent before the
|
||||
/// TLS handshake completes, and sends it as soon
|
||||
/// as it can. See [`CommonState::set_buffer_limit`] to control
|
||||
/// the size of this buffer.
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> task::Poll<io::Result<usize>> {
|
||||
let fut = match self.fut.as_mut() {
|
||||
Some(fut) => fut,
|
||||
None => {
|
||||
let fut = self.con.send_some_plaintext(buf);
|
||||
self.fut.get_or_insert(ReusableBoxFuture::new(fut))
|
||||
}
|
||||
};
|
||||
match fut.poll(cx) {
|
||||
task::Poll::Ready(sz) => task::Poll::Ready(Ok(sz)),
|
||||
task::Poll::Pending => task::Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
|
||||
task::Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
) -> task::Poll<Result<(), io::Error>> {
|
||||
task::Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
pub(crate) enum Protocol {
|
||||
@@ -258,6 +205,7 @@ pub struct ConnectionCommon {
|
||||
pub(crate) common_state: CommonState,
|
||||
message_deframer: MessageDeframer,
|
||||
handshake_joiner: HandshakeJoiner,
|
||||
write_fut: Option<ReusableBoxFuture<'static, usize>>,
|
||||
}
|
||||
|
||||
impl ConnectionCommon {
|
||||
@@ -272,6 +220,7 @@ impl ConnectionCommon {
|
||||
common_state,
|
||||
message_deframer: MessageDeframer::new(),
|
||||
handshake_joiner: HandshakeJoiner::new(),
|
||||
write_fut: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -287,9 +236,9 @@ impl ConnectionCommon {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an object that allows writing plaintext.
|
||||
pub fn writer(&mut self) -> Writer {
|
||||
Writer::new(self)
|
||||
/// Writes plaintext
|
||||
pub async fn write(&mut self, buf: &[u8]) -> usize {
|
||||
self.send_some_plaintext(buf).await
|
||||
}
|
||||
|
||||
/// This function uses `io` to complete any outstanding IO for
|
||||
@@ -372,7 +321,7 @@ impl ConnectionCommon {
|
||||
///
|
||||
/// This is a shortcut to the `process_new_packets()` -> `process_msg()` ->
|
||||
/// `process_handshake_messages()` path, specialized for the first handshake message.
|
||||
pub(crate) fn first_handshake_message(&mut self) -> Result<Option<Message>, Error> {
|
||||
pub(crate) async fn first_handshake_message(&mut self) -> Result<Option<Message>, Error> {
|
||||
if self.message_deframer.desynced {
|
||||
return Err(Error::CorruptMessage);
|
||||
}
|
||||
@@ -389,7 +338,8 @@ impl ConnectionCommon {
|
||||
|
||||
if self.handshake_joiner.take_message(msg).is_none() {
|
||||
self.common_state
|
||||
.send_fatal_alert(AlertDescription::DecodeError);
|
||||
.send_fatal_alert(AlertDescription::DecodeError)
|
||||
.await;
|
||||
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
|
||||
}
|
||||
|
||||
@@ -418,7 +368,8 @@ impl ConnectionCommon {
|
||||
// which receives a protected change_cipher_spec record MUST abort the
|
||||
// handshake with an "unexpected_message" alert."
|
||||
self.common_state
|
||||
.send_fatal_alert(AlertDescription::UnexpectedMessage);
|
||||
.send_fatal_alert(AlertDescription::UnexpectedMessage)
|
||||
.await;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"illegal middlebox CCS received".into(),
|
||||
));
|
||||
@@ -450,11 +401,15 @@ impl ConnectionCommon {
|
||||
// First decryptable handshake message concludes trial decryption
|
||||
self.common_state.record_layer.finish_trial_decryption();
|
||||
|
||||
self.handshake_joiner.take_message(msg).ok_or_else(|| {
|
||||
self.common_state
|
||||
.send_fatal_alert(AlertDescription::DecodeError);
|
||||
Error::CorruptMessagePayload(ContentType::Handshake)
|
||||
})?;
|
||||
match self.handshake_joiner.take_message(msg) {
|
||||
Some(_) => {}
|
||||
None => {
|
||||
self.common_state
|
||||
.send_fatal_alert(AlertDescription::DecodeError)
|
||||
.await;
|
||||
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
|
||||
}
|
||||
}
|
||||
return self.process_new_handshake_messages(state).await;
|
||||
}
|
||||
|
||||
@@ -463,7 +418,7 @@ impl ConnectionCommon {
|
||||
|
||||
// For alerts, we have separate logic.
|
||||
if let MessagePayload::Alert(alert) = &msg.payload {
|
||||
self.common_state.process_alert(alert)?;
|
||||
self.common_state.process_alert(alert).await?;
|
||||
return Ok(state);
|
||||
}
|
||||
|
||||
@@ -534,7 +489,7 @@ impl ConnectionCommon {
|
||||
|
||||
pub(crate) async fn send_some_plaintext(&mut self, buf: &[u8]) -> usize {
|
||||
if let Ok(st) = &mut self.state {
|
||||
st.perhaps_write_key_update(&mut self.common_state);
|
||||
st.perhaps_write_key_update(&mut self.common_state).await;
|
||||
}
|
||||
self.common_state.send_some_plaintext(buf).await
|
||||
}
|
||||
@@ -730,7 +685,8 @@ impl CommonState {
|
||||
Side::Server => HandshakeType::ClientHello,
|
||||
};
|
||||
if msg.is_handshake_type(reject_ty) {
|
||||
self.send_warning_alert(AlertDescription::NoRenegotiation);
|
||||
self.send_warning_alert(AlertDescription::NoRenegotiation)
|
||||
.await;
|
||||
return Ok(state);
|
||||
}
|
||||
}
|
||||
@@ -743,7 +699,8 @@ impl CommonState {
|
||||
}
|
||||
Err(e @ Error::InappropriateMessage { .. })
|
||||
| Err(e @ Error::InappropriateHandshakeMessage { .. }) => {
|
||||
self.send_fatal_alert(AlertDescription::UnexpectedMessage);
|
||||
self.send_fatal_alert(AlertDescription::UnexpectedMessage)
|
||||
.await;
|
||||
Err(e)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
@@ -775,9 +732,10 @@ impl CommonState {
|
||||
// messages. Otherwise the defragmented messages will have
|
||||
// been protected with two different record layer protections,
|
||||
// which is illegal. Not mentioned in RFC.
|
||||
pub(crate) fn check_aligned_handshake(&mut self) -> Result<(), Error> {
|
||||
pub(crate) async fn check_aligned_handshake(&mut self) -> Result<(), Error> {
|
||||
if !self.aligned_handshake {
|
||||
self.send_fatal_alert(AlertDescription::UnexpectedMessage);
|
||||
self.send_fatal_alert(AlertDescription::UnexpectedMessage)
|
||||
.await;
|
||||
Err(Error::PeerMisbehavedError(
|
||||
"key epoch or handshake flight with pending fragment".to_string(),
|
||||
))
|
||||
@@ -786,8 +744,9 @@ impl CommonState {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn illegal_param(&mut self, why: &str) -> Error {
|
||||
self.send_fatal_alert(AlertDescription::IllegalParameter);
|
||||
pub(crate) async fn illegal_param(&mut self, why: &str) -> Error {
|
||||
self.send_fatal_alert(AlertDescription::IllegalParameter)
|
||||
.await;
|
||||
Error::PeerMisbehavedError(why.to_string())
|
||||
}
|
||||
|
||||
@@ -796,7 +755,7 @@ impl CommonState {
|
||||
encr: OpaqueMessage,
|
||||
) -> Result<Option<PlainMessage>, Error> {
|
||||
if self.record_layer.wants_close_before_decrypt() {
|
||||
self.send_close_notify();
|
||||
self.send_close_notify().await;
|
||||
}
|
||||
|
||||
let encrypted_len = encr.payload.0.len();
|
||||
@@ -804,7 +763,8 @@ impl CommonState {
|
||||
|
||||
match plain {
|
||||
Err(Error::PeerSentOversizedRecord) => {
|
||||
self.send_fatal_alert(AlertDescription::RecordOverflow);
|
||||
self.send_fatal_alert(AlertDescription::RecordOverflow)
|
||||
.await;
|
||||
Err(Error::PeerSentOversizedRecord)
|
||||
}
|
||||
Err(Error::DecryptError) if self.record_layer.doing_trial_decryption(encrypted_len) => {
|
||||
@@ -812,7 +772,7 @@ impl CommonState {
|
||||
Ok(None)
|
||||
}
|
||||
Err(Error::DecryptError) => {
|
||||
self.send_fatal_alert(AlertDescription::BadRecordMac);
|
||||
self.send_fatal_alert(AlertDescription::BadRecordMac).await;
|
||||
Err(Error::DecryptError)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
@@ -861,7 +821,7 @@ impl CommonState {
|
||||
// Close connection once we start to run out of
|
||||
// sequence space.
|
||||
if self.record_layer.wants_close_before_encrypt() {
|
||||
self.send_close_notify();
|
||||
self.send_close_notify().await;
|
||||
}
|
||||
|
||||
// Refuse to wrap counter at all costs. This
|
||||
@@ -967,13 +927,13 @@ impl CommonState {
|
||||
|
||||
/// Send any buffered plaintext. Plaintext is buffered if
|
||||
/// written during handshake.
|
||||
fn flush_plaintext(&mut self) {
|
||||
async fn flush_plaintext(&mut self) {
|
||||
if !self.may_send_application_data {
|
||||
return;
|
||||
}
|
||||
|
||||
while let Some(buf) = self.sendable_plaintext.pop() {
|
||||
self.send_plain(&buf, Limit::No);
|
||||
self.send_plain(&buf, Limit::No).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -983,7 +943,7 @@ impl CommonState {
|
||||
}
|
||||
|
||||
/// Send a raw TLS message, fragmenting it if needed.
|
||||
pub(crate) fn send_msg(&mut self, m: Message, must_encrypt: bool) {
|
||||
pub(crate) async fn send_msg(&mut self, m: Message, must_encrypt: bool) {
|
||||
if !must_encrypt {
|
||||
let mut to_send = VecDeque::new();
|
||||
self.message_fragmenter.fragment(m.into(), &mut to_send);
|
||||
@@ -991,7 +951,7 @@ impl CommonState {
|
||||
self.queue_tls_message(mm.into_unencrypted_opaque());
|
||||
}
|
||||
} else {
|
||||
self.send_msg_encrypt(m.into());
|
||||
self.send_msg_encrypt(m.into()).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1006,15 +966,16 @@ impl CommonState {
|
||||
self.record_layer.prepare_message_decrypter(dec);
|
||||
}
|
||||
|
||||
fn send_warning_alert(&mut self, desc: AlertDescription) {
|
||||
async fn send_warning_alert(&mut self, desc: AlertDescription) {
|
||||
warn!("Sending warning alert {:?}", desc);
|
||||
self.send_warning_alert_no_log(desc);
|
||||
self.send_warning_alert_no_log(desc).await;
|
||||
}
|
||||
|
||||
fn process_alert(&mut self, alert: &AlertMessagePayload) -> Result<(), Error> {
|
||||
async fn process_alert(&mut self, alert: &AlertMessagePayload) -> Result<(), Error> {
|
||||
// Reject unknown AlertLevels.
|
||||
if let AlertLevel::Unknown(_) = alert.level {
|
||||
self.send_fatal_alert(AlertDescription::IllegalParameter);
|
||||
self.send_fatal_alert(AlertDescription::IllegalParameter)
|
||||
.await;
|
||||
}
|
||||
|
||||
// If we get a CloseNotify, make a note to declare EOF to our
|
||||
@@ -1028,7 +989,7 @@ impl CommonState {
|
||||
// (except, for no good reason, user_cancelled).
|
||||
if alert.level == AlertLevel::Warning {
|
||||
if self.is_tls13() && alert.description != AlertDescription::UserCanceled {
|
||||
self.send_fatal_alert(AlertDescription::DecodeError);
|
||||
self.send_fatal_alert(AlertDescription::DecodeError).await;
|
||||
} else {
|
||||
warn!("TLS alert warning received: {:#?}", alert);
|
||||
return Ok(());
|
||||
@@ -1039,25 +1000,26 @@ impl CommonState {
|
||||
Err(Error::AlertReceived(alert.description))
|
||||
}
|
||||
|
||||
pub(crate) fn send_fatal_alert(&mut self, desc: AlertDescription) {
|
||||
pub(crate) async fn send_fatal_alert(&mut self, desc: AlertDescription) {
|
||||
warn!("Sending fatal alert {:?}", desc);
|
||||
debug_assert!(!self.sent_fatal_alert);
|
||||
let m = Message::build_alert(AlertLevel::Fatal, desc);
|
||||
self.send_msg(m, self.record_layer.is_encrypting());
|
||||
self.send_msg(m, self.record_layer.is_encrypting()).await;
|
||||
self.sent_fatal_alert = true;
|
||||
}
|
||||
|
||||
/// Queues a close_notify warning alert to be sent in the next
|
||||
/// [`CommonState::write_tls`] call. This informs the peer that the
|
||||
/// connection is being closed.
|
||||
pub fn send_close_notify(&mut self) {
|
||||
pub async fn send_close_notify(&mut self) {
|
||||
debug!("Sending warning alert {:?}", AlertDescription::CloseNotify);
|
||||
self.send_warning_alert_no_log(AlertDescription::CloseNotify);
|
||||
self.send_warning_alert_no_log(AlertDescription::CloseNotify)
|
||||
.await;
|
||||
}
|
||||
|
||||
fn send_warning_alert_no_log(&mut self, desc: AlertDescription) {
|
||||
async fn send_warning_alert_no_log(&mut self, desc: AlertDescription) {
|
||||
let m = Message::build_alert(AlertLevel::Warning, desc);
|
||||
self.send_msg(m, self.record_layer.is_encrypting());
|
||||
self.send_msg(m, self.record_layer.is_encrypting()).await;
|
||||
}
|
||||
|
||||
pub(crate) fn set_max_fragment_size(&mut self, new: Option<usize>) -> Result<(), Error> {
|
||||
@@ -1112,7 +1074,7 @@ pub(crate) trait State<ClientConnectionData>: Send + Sync {
|
||||
Err(Error::HandshakeNotComplete)
|
||||
}
|
||||
|
||||
fn perhaps_write_key_update(&mut self, _cx: &mut CommonState) {}
|
||||
async fn perhaps_write_key_update(&mut self, _cx: &mut CommonState) {}
|
||||
}
|
||||
|
||||
pub(crate) struct Context<'a> {
|
||||
|
||||
@@ -357,7 +357,7 @@ pub use crate::anchors::{OwnedTrustAnchor, RootCertStore};
|
||||
pub use crate::builder::{
|
||||
ConfigBuilder, WantsCipherSuites, WantsKxGroups, WantsVerifier, WantsVersions,
|
||||
};
|
||||
pub use crate::conn::{CommonState, ConnectionCommon, IoState, Reader, SideData, Writer};
|
||||
pub use crate::conn::{CommonState, ConnectionCommon, IoState, Reader, SideData};
|
||||
pub use crate::error::Error;
|
||||
pub use crate::key::{Certificate, PrivateKey};
|
||||
pub use crate::key_log::{KeyLog, NoKeyLog};
|
||||
@@ -393,7 +393,7 @@ pub mod client {
|
||||
pub use client_conn::ResolvesClientCert;
|
||||
pub use client_conn::ServerName;
|
||||
pub use client_conn::StoresClientSessions;
|
||||
pub use client_conn::{ClientConfig, ClientConnection, ClientConnectionData, WriteEarlyData};
|
||||
pub use client_conn::{ClientConfig, ClientConnection, ClientConnectionData};
|
||||
pub use handy::{ClientSessionMemoryCache, NoClientSessionStorage};
|
||||
|
||||
#[cfg(feature = "dangerous_configuration")]
|
||||
|
||||
@@ -104,7 +104,7 @@ pub trait ServerCertVerifier: Send + Sync {
|
||||
end_entity: &Certificate,
|
||||
intermediates: &[Certificate],
|
||||
server_name: &ServerName,
|
||||
scts: &mut dyn Iterator<Item = &[u8]>,
|
||||
scts: &mut (dyn Iterator<Item = &[u8]> + Send),
|
||||
ocsp_response: &[u8],
|
||||
now: SystemTime,
|
||||
) -> Result<ServerCertVerified, Error>;
|
||||
@@ -295,7 +295,7 @@ impl ServerCertVerifier for WebPkiVerifier {
|
||||
end_entity: &Certificate,
|
||||
intermediates: &[Certificate],
|
||||
server_name: &ServerName,
|
||||
scts: &mut dyn Iterator<Item = &[u8]>,
|
||||
scts: &mut (dyn Iterator<Item = &[u8]> + Send),
|
||||
ocsp_response: &[u8],
|
||||
now: SystemTime,
|
||||
) -> Result<ServerCertVerified, Error> {
|
||||
|
||||
Reference in New Issue
Block a user